|
13 | 13 | """This module stores JumpStart implementation of Model class."""
|
14 | 14 |
|
15 | 15 | from __future__ import absolute_import
|
16 |
| -import re |
17 | 16 |
|
18 | 17 | from typing import Dict, List, Optional, Union
|
19 | 18 | from sagemaker import payloads
|
|
28 | 27 | get_default_predictor,
|
29 | 28 | get_deploy_kwargs,
|
30 | 29 | get_init_kwargs,
|
| 30 | + get_register_kwargs, |
31 | 31 | )
|
32 | 32 | from sagemaker.jumpstart.types import JumpStartSerializablePayload
|
33 | 33 | from sagemaker.jumpstart.utils import is_valid_model_id
|
34 | 34 | from sagemaker.utils import stringify_object
|
35 |
| -from sagemaker.model import MODEL_PACKAGE_ARN_PATTERN, Model |
| 35 | +from sagemaker.model import ( |
| 36 | + Model, |
| 37 | + ModelPackage, |
| 38 | +) |
36 | 39 | from sagemaker.model_monitor.data_capture_config import DataCaptureConfig
|
37 | 40 | from sagemaker.predictor import PredictorBase
|
38 | 41 | from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
|
39 | 42 | from sagemaker.session import Session
|
40 | 43 | from sagemaker.workflow.entities import PipelineVariable
|
| 44 | +from sagemaker.model_metrics import ModelMetrics |
| 45 | +from sagemaker.metadata_properties import MetadataProperties |
| 46 | +from sagemaker.drift_check_baselines import DriftCheckBaselines |
41 | 47 |
|
42 | 48 |
|
43 | 49 | class JumpStartModel(Model):
|
@@ -309,11 +315,12 @@ def _is_valid_model_id_hook():
|
309 | 315 | self.tolerate_vulnerable_model = model_init_kwargs.tolerate_vulnerable_model
|
310 | 316 | self.tolerate_deprecated_model = model_init_kwargs.tolerate_deprecated_model
|
311 | 317 | self.region = model_init_kwargs.region
|
312 |
| - self.model_package_arn = model_init_kwargs.model_package_arn |
313 | 318 | self.sagemaker_session = model_init_kwargs.sagemaker_session
|
314 | 319 |
|
315 | 320 | super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
|
316 | 321 |
|
| 322 | + self.model_package_arn = model_init_kwargs.model_package_arn |
| 323 | + |
317 | 324 | def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]:
|
318 | 325 | """Returns all example payloads associated with the model.
|
319 | 326 |
|
@@ -390,30 +397,29 @@ def _create_sagemaker_model(
|
390 | 397 | # inference endpoint.
|
391 | 398 | if self.model_package_arn and not self._model_data_is_set:
|
392 | 399 | # When a ModelPackageArn is provided we just create the Model
|
393 |
| - match = re.match(MODEL_PACKAGE_ARN_PATTERN, self.model_package_arn) |
394 |
| - if match: |
395 |
| - model_package_name = match.group(3) |
396 |
| - else: |
397 |
| - # model_package_arn can be just the name if your account owns the Model Package |
398 |
| - model_package_name = self.model_package_arn |
399 |
| - container_def = {"ModelPackageName": self.model_package_arn} |
400 |
| - |
401 |
| - if self.env != {}: |
402 |
| - container_def["Environment"] = self.env |
403 |
| - |
404 |
| - if self.name is None: |
405 |
| - self._base_name = model_package_name |
406 |
| - |
407 |
| - self._set_model_name_if_needed() |
408 |
| - |
409 |
| - self.sagemaker_session.create_model( |
410 |
| - self.name, |
411 |
| - self.role, |
412 |
| - container_def, |
| 400 | + model_package = ModelPackage( |
| 401 | + role=self.role, |
| 402 | + model_data=self.model_data, |
| 403 | + model_package_arn=self.model_package_arn, |
| 404 | + sagemaker_session=self.sagemaker_session, |
| 405 | + predictor_cls=self.predictor_cls, |
413 | 406 | vpc_config=self.vpc_config,
|
414 |
| - enable_network_isolation=self.enable_network_isolation(), |
| 407 | + ) |
| 408 | + if self.name is not None: |
| 409 | + model_package.name = self.name |
| 410 | + if self.env is not None: |
| 411 | + model_package.env = self.env |
| 412 | + model_package._create_sagemaker_model( |
| 413 | + instance_type=instance_type, |
| 414 | + accelerator_type=accelerator_type, |
415 | 415 | tags=tags,
|
| 416 | + serverless_inference_config=serverless_inference_config, |
| 417 | + **kwargs, |
416 | 418 | )
|
| 419 | + if self._base_name is None and model_package._base_name is not None: |
| 420 | + self._base_name = model_package._base_name |
| 421 | + if self.name is None and model_package.name is not None: |
| 422 | + self.name = model_package.name |
417 | 423 | else:
|
418 | 424 | super(JumpStartModel, self)._create_sagemaker_model(
|
419 | 425 | instance_type=instance_type,
|
@@ -565,6 +571,124 @@ def deploy(
|
565 | 571 | # If a predictor class was passed, do not mutate predictor
|
566 | 572 | return predictor
|
567 | 573 |
|
| 574 | + def register( |
| 575 | + self, |
| 576 | + content_types: List[Union[str, PipelineVariable]] = None, |
| 577 | + response_types: List[Union[str, PipelineVariable]] = None, |
| 578 | + inference_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 579 | + transform_instances: Optional[List[Union[str, PipelineVariable]]] = None, |
| 580 | + model_package_group_name: Optional[Union[str, PipelineVariable]] = None, |
| 581 | + image_uri: Optional[Union[str, PipelineVariable]] = None, |
| 582 | + model_metrics: Optional[ModelMetrics] = None, |
| 583 | + metadata_properties: Optional[MetadataProperties] = None, |
| 584 | + approval_status: Optional[Union[str, PipelineVariable]] = None, |
| 585 | + description: Optional[str] = None, |
| 586 | + drift_check_baselines: Optional[DriftCheckBaselines] = None, |
| 587 | + customer_metadata_properties: Optional[Dict[str, Union[str, PipelineVariable]]] = None, |
| 588 | + validation_specification: Optional[Union[str, PipelineVariable]] = None, |
| 589 | + domain: Optional[Union[str, PipelineVariable]] = None, |
| 590 | + task: Optional[Union[str, PipelineVariable]] = None, |
| 591 | + sample_payload_url: Optional[Union[str, PipelineVariable]] = None, |
| 592 | + framework: Optional[Union[str, PipelineVariable]] = None, |
| 593 | + framework_version: Optional[Union[str, PipelineVariable]] = None, |
| 594 | + nearest_model_name: Optional[Union[str, PipelineVariable]] = None, |
| 595 | + data_input_configuration: Optional[Union[str, PipelineVariable]] = None, |
| 596 | + skip_model_validation: Optional[Union[str, PipelineVariable]] = None, |
| 597 | + ): |
| 598 | + """Creates a model package for creating SageMaker models or listing on Marketplace. |
| 599 | +
|
| 600 | + Args: |
| 601 | + content_types (list[str] or list[PipelineVariable]): The supported MIME types |
| 602 | + for the input data. |
| 603 | + response_types (list[str] or list[PipelineVariable]): The supported MIME types |
| 604 | + for the output data. |
| 605 | + inference_instances (list[str] or list[PipelineVariable]): A list of the instance |
| 606 | + types that are used to generate inferences in real-time (default: None). |
| 607 | + transform_instances (list[str] or list[PipelineVariable]): A list of the instance types |
| 608 | + on which a transformation job can be run or on which an endpoint can be deployed |
| 609 | + (default: None). |
| 610 | + model_package_group_name (str or PipelineVariable): Model Package Group name, |
| 611 | + exclusive to `model_package_name`, using `model_package_group_name` makes the |
| 612 | + Model Package versioned. Defaults to ``None``. |
| 613 | + image_uri (str or PipelineVariable): Inference image URI for the container. Model class' |
| 614 | + self.image will be used if it is None. Defaults to ``None``. |
| 615 | + model_metrics (ModelMetrics): ModelMetrics object. Defaults to ``None``. |
| 616 | + metadata_properties (MetadataProperties): MetadataProperties object. |
| 617 | + Defaults to ``None``. |
| 618 | + approval_status (str or PipelineVariable): Model Approval Status, values can be |
| 619 | + "Approved", "Rejected", or "PendingManualApproval". Defaults to |
| 620 | + ``PendingManualApproval``. |
| 621 | + description (str): Model Package description. Defaults to ``None``. |
| 622 | + drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None). |
| 623 | + customer_metadata_properties (dict[str, str] or dict[str, PipelineVariable]): |
| 624 | + A dictionary of key-value paired metadata properties (default: None). |
| 625 | + domain (str or PipelineVariable): Domain values can be "COMPUTER_VISION", |
| 626 | + "NATURAL_LANGUAGE_PROCESSING", "MACHINE_LEARNING" (default: None). |
| 627 | + sample_payload_url (str or PipelineVariable): The S3 path where the sample payload |
| 628 | + is stored (default: None). |
| 629 | + task (str or PipelineVariable): Task values which are supported by Inference Recommender |
| 630 | + are "FILL_MASK", "IMAGE_CLASSIFICATION", "OBJECT_DETECTION", "TEXT_GENERATION", |
| 631 | + "IMAGE_SEGMENTATION", "CLASSIFICATION", "REGRESSION", "OTHER" (default: None). |
| 632 | + framework (str or PipelineVariable): Machine learning framework of the model package |
| 633 | + container image (default: None). |
| 634 | + framework_version (str or PipelineVariable): Framework version of the Model Package |
| 635 | + Container Image (default: None). |
| 636 | + nearest_model_name (str or PipelineVariable): Name of a pre-trained machine learning |
| 637 | + benchmarked by Amazon SageMaker Inference Recommender (default: None). |
| 638 | + data_input_configuration (str or PipelineVariable): Input object for the model |
| 639 | + (default: None). |
| 640 | + skip_model_validation (str or PipelineVariable): Indicates if you want to skip model |
| 641 | + validation. Values can be "All" or "None" (default: None). |
| 642 | +
|
| 643 | + Returns: |
| 644 | + A `sagemaker.model.ModelPackage` instance. |
| 645 | + """ |
| 646 | + |
| 647 | + register_kwargs = get_register_kwargs( |
| 648 | + model_id=self.model_id, |
| 649 | + model_version=self.model_version, |
| 650 | + region=self.region, |
| 651 | + tolerate_deprecated_model=self.tolerate_deprecated_model, |
| 652 | + tolerate_vulnerable_model=self.tolerate_vulnerable_model, |
| 653 | + sagemaker_session=self.sagemaker_session, |
| 654 | + supported_content_types=content_types, |
| 655 | + response_types=response_types, |
| 656 | + inference_instances=inference_instances, |
| 657 | + transform_instances=transform_instances, |
| 658 | + model_package_group_name=model_package_group_name, |
| 659 | + image_uri=image_uri, |
| 660 | + model_metrics=model_metrics, |
| 661 | + metadata_properties=metadata_properties, |
| 662 | + approval_status=approval_status, |
| 663 | + description=description, |
| 664 | + drift_check_baselines=drift_check_baselines, |
| 665 | + customer_metadata_properties=customer_metadata_properties, |
| 666 | + validation_specification=validation_specification, |
| 667 | + domain=domain, |
| 668 | + task=task, |
| 669 | + sample_payload_url=sample_payload_url, |
| 670 | + framework=framework, |
| 671 | + framework_version=framework_version, |
| 672 | + nearest_model_name=nearest_model_name, |
| 673 | + data_input_configuration=data_input_configuration, |
| 674 | + skip_model_validation=skip_model_validation, |
| 675 | + ) |
| 676 | + |
| 677 | + model_package = super(JumpStartModel, self).register(**register_kwargs.to_kwargs_dict()) |
| 678 | + |
| 679 | + def register_deploy_wrapper(*args, **kwargs): |
| 680 | + if self.model_package_arn is not None: |
| 681 | + return self.deploy(*args, **kwargs) |
| 682 | + else: |
| 683 | + self.model_package_arn = model_package.model_package_arn |
| 684 | + predictor = self.deploy(*args, **kwargs) |
| 685 | + self.model_package_arn = None |
| 686 | + return predictor |
| 687 | + |
| 688 | + model_package.deploy = register_deploy_wrapper |
| 689 | + |
| 690 | + return model_package |
| 691 | + |
568 | 692 | def __str__(self) -> str:
|
569 | 693 | """Overriding str(*) method to make more human-readable."""
|
570 | 694 | return stringify_object(self)
|
0 commit comments