|
14 | 14 |
|
15 | 15 | from __future__ import absolute_import
|
16 | 16 |
|
17 |
| -from typing import Dict, List, Optional, Union |
| 17 | +from functools import lru_cache |
| 18 | +from typing import Dict, List, Optional, Union, Any |
| 19 | +import pandas as pd |
18 | 20 | from botocore.exceptions import ClientError
|
19 | 21 |
|
20 | 22 | from sagemaker import payloads
|
|
36 | 38 | get_init_kwargs,
|
37 | 39 | get_register_kwargs,
|
38 | 40 | )
|
39 |
| -from sagemaker.jumpstart.types import JumpStartSerializablePayload |
| 41 | +from sagemaker.jumpstart.types import ( |
| 42 | + JumpStartSerializablePayload, |
| 43 | + DeploymentConfigMetadata, |
| 44 | + JumpStartBenchmarkStat, |
| 45 | + JumpStartMetadataConfig, |
| 46 | +) |
40 | 47 | from sagemaker.jumpstart.utils import (
|
41 | 48 | validate_model_id_and_get_type,
|
42 | 49 | verify_model_region_and_return_specs,
|
| 50 | + get_jumpstart_configs, |
| 51 | + extract_metrics_from_deployment_configs, |
43 | 52 | )
|
44 | 53 | from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
|
45 | 54 | from sagemaker.jumpstart.enums import JumpStartModelType
|
46 |
| -from sagemaker.utils import stringify_object, format_tags, Tags |
| 55 | +from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour |
47 | 56 | from sagemaker.model import (
|
48 | 57 | Model,
|
49 | 58 | ModelPackage,
|
@@ -352,6 +361,18 @@ def _validate_model_id_and_type():
|
352 | 361 | self.model_package_arn = model_init_kwargs.model_package_arn
|
353 | 362 | self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)
|
354 | 363 |
|
| 364 | + metadata_configs = get_jumpstart_configs( |
| 365 | + region=self.region, |
| 366 | + model_id=self.model_id, |
| 367 | + model_version=self.model_version, |
| 368 | + sagemaker_session=self.sagemaker_session, |
| 369 | + model_type=self.model_type, |
| 370 | + ) |
| 371 | + self._deployment_configs = [ |
| 372 | + self._convert_to_deployment_config_metadata(config_name, config) |
| 373 | + for config_name, config in metadata_configs.items() |
| 374 | + ] |
| 375 | + |
355 | 376 | def log_subscription_warning(self) -> None:
|
356 | 377 | """Log message prompting the customer to subscribe to the proprietary model."""
|
357 | 378 | subscription_link = verify_model_region_and_return_specs(
|
@@ -420,6 +441,27 @@ def set_deployment_config(self, config_name: Optional[str]) -> None:
|
420 | 441 | model_id=self.model_id, model_version=self.model_version, config_name=config_name
|
421 | 442 | )
|
422 | 443 |
|
@property
def benchmark_metrics(self) -> pd.DataFrame:
    """Benchmark metrics for the deployment configs, as a table.

    Returns:
        pd.DataFrame: Frame built from the benchmark data that
            ``_get_benchmark_data`` returns for ``self.config_name``.
    """
    selected_config = self.config_name
    metrics_by_column = self._get_benchmark_data(selected_config)
    return pd.DataFrame(metrics_by_column)
| 452 | + |
def display_benchmark_metrics(self) -> None:
    """Print the benchmark metrics of the deployment configs.

    The ``benchmark_metrics`` DataFrame is rendered with its
    ``to_markdown()`` method and written to standard output.
    """
    rendered_table = self.benchmark_metrics.to_markdown()
    print(rendered_table)
| 456 | + |
def list_deployment_configs(self) -> List[Dict[str, Any]]:
    """List deployment configs for this model.

    Returns:
        List[Dict[str, Any]]: A new list containing the deployment configs
            held by this model. The list itself is a copy; the config
            dictionaries inside it are the same objects the model holds.
    """
    # Return a shallow copy so that callers appending to, removing from or
    # reordering the result cannot alter the list the benchmark data is
    # computed (and cached) from.
    return list(self._deployment_configs)
| 464 | + |
423 | 465 | def _create_sagemaker_model(
|
424 | 466 | self,
|
425 | 467 | instance_type=None,
|
@@ -808,6 +850,67 @@ def register_deploy_wrapper(*args, **kwargs):
|
808 | 850 |
|
809 | 851 | return model_package
|
810 | 852 |
|
def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
    """Constructs deployment configs benchmark data.

    The result for each ``config_name`` is computed once per instance and
    then served from a cache stored on the instance itself.

    Args:
        config_name (str): The name of the selected deployment config.
    Returns:
        Dict[str, List[str]]: Deployment config benchmark data.
    """
    # A per-instance dict replaces ``functools.lru_cache`` on the method.
    # The decorator keeps its cache on the function object and keys it on
    # ``self``: that requires hashable instances, and holds a strong
    # reference to every instance for as long as its entry stays cached.
    # The attribute is created lazily because this method cannot rely on
    # the constructor having initialised it.
    cache = getattr(self, "_benchmark_data_cache", None)
    if cache is None:
        cache = {}
        self._benchmark_data_cache = cache

    if config_name not in cache:
        cache[config_name] = extract_metrics_from_deployment_configs(
            self._deployment_configs,
            config_name,
        )
    return cache[config_name]
| 866 | + |
def _convert_to_deployment_config_metadata(
    self, config_name: str, metadata_config: JumpStartMetadataConfig
) -> Dict[str, Any]:
    """Retrieve deployment config for config name.

    Looks up the config's default inference instance type, gathers the
    benchmark metrics recorded for that instance type (plus the instance's
    hourly rate when one is available), resolves the init and deploy kwargs
    for that instance type, and returns the resulting
    ``DeploymentConfigMetadata`` in its ``to_json()`` form.

    Args:
        config_name (str): Name of deployment config.
        metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.
    Returns:
        A deployment metadata config for config name (dict[str, Any]).
    """
    default_inference_instance_type = metadata_config.resolved_config.get(
        "default_inference_instance_type"
    )

    instance_rate = get_instance_rate_per_hour(
        instance_type=default_inference_instance_type, region=self.region
    )

    metrics_by_instance_type = metadata_config.benchmark_metrics
    benchmark_metrics = (
        metrics_by_instance_type.get(default_inference_instance_type)
        if metrics_by_instance_type is not None
        else None
    )
    if instance_rate is not None:
        # Build a new list instead of appending in place: the looked-up list
        # belongs to ``metadata_config``, and mutating it would leave an
        # extra instance-rate entry behind on that object each time this
        # method runs against it.
        # NOTE(review): presumably the metadata objects can be reused across
        # models or calls, which is when duplicates would show up - confirm.
        existing_stats = list(benchmark_metrics) if benchmark_metrics is not None else []
        benchmark_metrics = existing_stats + [JumpStartBenchmarkStat(instance_rate)]

    init_kwargs = get_init_kwargs(
        model_id=self.model_id,
        instance_type=default_inference_instance_type,
        sagemaker_session=self.sagemaker_session,
    )
    deploy_kwargs = get_deploy_kwargs(
        model_id=self.model_id,
        instance_type=default_inference_instance_type,
        sagemaker_session=self.sagemaker_session,
    )

    deployment_config_metadata = DeploymentConfigMetadata(
        config_name, benchmark_metrics, init_kwargs, deploy_kwargs
    )

    return deployment_config_metadata.to_json()
| 913 | + |
def __str__(self) -> str:
    """Return the more human-readable text that ``stringify_object`` builds for this object."""
    readable_form = stringify_object(self)
    return readable_form
|
0 commit comments