```diff
@@ -41,18 +41,17 @@
 from sagemaker.jumpstart.types import (
     JumpStartSerializablePayload,
     DeploymentConfigMetadata,
-    JumpStartBenchmarkStat,
-    JumpStartMetadataConfig,
 )
 from sagemaker.jumpstart.utils import (
     validate_model_id_and_get_type,
     verify_model_region_and_return_specs,
     get_jumpstart_configs,
     get_metrics_from_deployment_configs,
+    add_instance_rate_stats_to_benchmark_metrics,
 )
 from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
 from sagemaker.jumpstart.enums import JumpStartModelType
-from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour
+from sagemaker.utils import stringify_object, format_tags, Tags
 from sagemaker.model import (
     Model,
     ModelPackage,
```
```diff
@@ -361,17 +360,13 @@ def _validate_model_id_and_type():
         self.model_package_arn = model_init_kwargs.model_package_arn
         self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)

-        metadata_configs = get_jumpstart_configs(
+        self._metadata_configs = get_jumpstart_configs(
             region=self.region,
             model_id=self.model_id,
             model_version=self.model_version,
             sagemaker_session=self.sagemaker_session,
             model_type=self.model_type,
         )
-        self._deployment_configs = [
-            self._convert_to_deployment_config_metadata(config_name, config)
-            for config_name, config in metadata_configs.items()
-        ]

     def log_subscription_warning(self) -> None:
         """Log message prompting the customer to subscribe to the proprietary model."""
```
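Why the constructor change above matters: the old `__init__` eagerly built every `DeploymentConfigMetadata` (including per-instance-type pricing lookups) on model creation. The refactor stores only the raw `self._metadata_configs` and defers the expensive work to the memoized `_get_deployment_configs` helper introduced further down. A minimal sketch of the deferral pattern, with illustrative names:

```python
from functools import lru_cache

class ExampleModel:
    """Toy stand-in for JumpStartModel; configs and return values are illustrative."""

    def __init__(self):
        # __init__ now only remembers the raw configs; nothing expensive runs here.
        self._metadata_configs = {"config-a": {}, "config-b": {}}

    @lru_cache
    def _get_deployment_configs(self, config_name, instance_type):
        # Expensive work (pricing lookups, kwarg resolution) happens here,
        # once per (config_name, instance_type) pair.
        return [f"built {name} for {instance_type}" for name in self._metadata_configs]

m = ExampleModel()
m._get_deployment_configs("config-a", "ml.g5.2xlarge")  # computed
m._get_deployment_configs("config-a", "ml.g5.2xlarge")  # served from the cache
```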
```diff
@@ -449,33 +444,46 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:

     @property
     def deployment_config(self) -> Optional[Dict[str, Any]]:
-        """The deployment config that will be applied to the model.
+        """The deployment config that will be applied to this model.

         Returns:
-            Optional[Dict[str, Any]]: Deployment config that will be applied to the model.
+            Optional[Dict[str, Any]]: Deployment config.
         """
-        return self._retrieve_selected_deployment_config(self.config_name)
+        deployment_config = self._retrieve_selected_deployment_config(
+            self.config_name, self.instance_type
+        )
+        return deployment_config.to_json() if deployment_config is not None else None

     @property
     def benchmark_metrics(self) -> pd.DataFrame:
-        """Benchmark Metrics for deployment configs
+        """Benchmark metrics for deployment configs.

         Returns:
-            Metrics: Pandas DataFrame object.
+            Benchmark metrics: Pandas DataFrame object.
         """
-        return pd.DataFrame(self._get_benchmarks_data(self.config_name))
+        benchmark_metrics_data = self._get_deployment_configs_benchmarks_data(
+            self.config_name, self.instance_type
+        )
+        keys = list(benchmark_metrics_data.keys())
+        df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[0], keys[1]])
+        return df

     def display_benchmark_metrics(self) -> None:
-        """Display Benchmark Metrics for deployment configs."""
-        print(self.benchmark_metrics.to_markdown())
+        """Display deployment configs benchmark metrics."""
+        print(self.benchmark_metrics.to_markdown(index=False))

     def list_deployment_configs(self) -> List[Dict[str, Any]]:
         """List deployment configs for ``This`` model.

         Returns:
             List[Dict[str, Any]]: A list of deployment configs.
         """
-        return self._deployment_configs
+        return [
+            deployment_config.to_json()
+            for deployment_config in self._get_deployment_configs(
+                self.config_name, self.instance_type
+            )
+        ]

     def _create_sagemaker_model(
         self,
```
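For context, the reworked properties give roughly the following user-facing flow. A minimal usage sketch: the model ID and instance type are illustrative assumptions, and the `DeploymentConfigName` key is inferred from the removed lookup code above, not confirmed elsewhere in this diff. The `set_deployment_config(config_name, instance_type)` signature comes from the hunk header.

```python
from sagemaker.jumpstart.model import JumpStartModel

# Hypothetical model ID for illustration.
model = JumpStartModel(model_id="huggingface-llm-example-7b")

# Pretty-print benchmark metrics for all deployment configs,
# now rendered without the DataFrame index column.
model.display_benchmark_metrics()

# Each entry is a DeploymentConfigMetadata serialized via to_json().
configs = model.list_deployment_configs()

# Pin a config and instance type; deployment_config then reflects the choice.
model.set_deployment_config(
    config_name=configs[0]["DeploymentConfigName"],  # assumed key, per the old lookup code
    instance_type="ml.g5.2xlarge",
)
print(model.deployment_config)
```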
```diff
@@ -866,92 +874,94 @@ def register_deploy_wrapper(*args, **kwargs):
         return model_package

     @lru_cache
-    def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
+    def _get_deployment_configs_benchmarks_data(
+        self, config_name: str, instance_type: str
+    ) -> Dict[str, Any]:
         """Deployment configs benchmark metrics.

         Args:
-            config_name (str): The name of the selected deployment config.
+            config_name (str): Name of the selected deployment config.
+            instance_type (str): The selected instance type.
         Returns:
             Dict[str, List[str]]: Deployment config benchmark data.
         """
         return get_metrics_from_deployment_configs(
-            self._deployment_configs,
-            config_name,
+            self._get_deployment_configs(config_name, instance_type)
         )

     @lru_cache
-    def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]:
-        """Retrieve the deployment config to apply to the model.
+    def _retrieve_selected_deployment_config(
+        self, config_name: str, instance_type: str
+    ) -> Optional[DeploymentConfigMetadata]:
+        """Retrieve the deployment config to apply to this model.

         Args:
             config_name (str): The name of the deployment config to retrieve.
+            instance_type (str): The instance type of the deployment config to retrieve.
         Returns:
             Optional[Dict[str, Any]]: The retrieved deployment config.
         """
         if config_name is None:
             return None

-        for deployment_config in self._deployment_configs:
-            if deployment_config.get("DeploymentConfigName") == config_name:
+        for deployment_config in self._get_deployment_configs(config_name, instance_type):
+            if deployment_config.deployment_config_name == config_name:
                 return deployment_config
         return None

-    def _convert_to_deployment_config_metadata(
-        self, config_name: str, metadata_config: JumpStartMetadataConfig
-    ) -> Dict[str, Any]:
-        """Retrieve deployment config for config name.
+    @lru_cache
+    def _get_deployment_configs(
+        self, selected_config_name: str, selected_instance_type: str
+    ) -> List[DeploymentConfigMetadata]:
+        """Retrieve deployment configs metadata.

         Args:
-            config_name (str): Name of deployment config.
-            metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.
-        Returns:
-            A deployment metadata config for config name (dict[str, Any]).
+            selected_config_name (str): The name of the selected deployment config.
+            selected_instance_type (str): The selected instance type.
         """
-        default_inference_instance_type = metadata_config.resolved_config.get(
-            "default_inference_instance_type"
-        )
-
-        benchmark_metrics = (
-            metadata_config.benchmark_metrics.get(default_inference_instance_type)
-            if metadata_config.benchmark_metrics is not None
-            else None
-        )
-
-        should_fetch_instance_rate_metric = True
-        if benchmark_metrics is not None:
-            for benchmark_metric in benchmark_metrics:
-                if benchmark_metric.name.lower() == "instance rate":
-                    should_fetch_instance_rate_metric = False
-                    break
-
-        if should_fetch_instance_rate_metric:
-            instance_rate = get_instance_rate_per_hour(
-                instance_type=default_inference_instance_type, region=self.region
+        deployment_configs = []
+        if self._metadata_configs is None:
+            return deployment_configs
+
+        err = None
+        for config_name, metadata_config in self._metadata_configs.items():
+            if err is None or "is not authorized to perform: pricing:GetProducts" not in err:
+                err, metadata_config.benchmark_metrics = (
+                    add_instance_rate_stats_to_benchmark_metrics(
+                        self.region, metadata_config.benchmark_metrics
+                    )
+                )
+
+            resolved_config = metadata_config.resolved_config
+            if selected_config_name == config_name:
+                instance_type_to_use = selected_instance_type
+            else:
+                instance_type_to_use = resolved_config.get("default_inference_instance_type")
+
+            init_kwargs = get_init_kwargs(
+                model_id=self.model_id,
+                instance_type=instance_type_to_use,
+                sagemaker_session=self.sagemaker_session,
             )
-            if instance_rate is not None:
-                instance_rate_metric = JumpStartBenchmarkStat(instance_rate)
-
-                if benchmark_metrics is None:
-                    benchmark_metrics = [instance_rate_metric]
-                else:
-                    benchmark_metrics.append(instance_rate_metric)
-
-        init_kwargs = get_init_kwargs(
-            model_id=self.model_id,
-            instance_type=default_inference_instance_type,
-            sagemaker_session=self.sagemaker_session,
-        )
-        deploy_kwargs = get_deploy_kwargs(
-            model_id=self.model_id,
-            instance_type=default_inference_instance_type,
-            sagemaker_session=self.sagemaker_session,
-        )
+            deploy_kwargs = get_deploy_kwargs(
+                model_id=self.model_id,
+                instance_type=instance_type_to_use,
+                sagemaker_session=self.sagemaker_session,
+            )
+            deployment_config_metadata = DeploymentConfigMetadata(
+                config_name,
+                metadata_config.benchmark_metrics,
+                resolved_config,
+                init_kwargs,
+                deploy_kwargs,
+            )
+            deployment_configs.append(deployment_config_metadata)

-        deployment_config_metadata = DeploymentConfigMetadata(
-            config_name, benchmark_metrics, init_kwargs, deploy_kwargs
-        )
+        if err is not None and "is not authorized to perform: pricing:GetProducts" in err:
+            error_message = "Instance rate metrics will be omitted. Reason: %s"
+            JUMPSTART_LOGGER.warning(error_message, err)

-        return deployment_config_metadata.to_json()
+        return deployment_configs

     def __str__(self) -> str:
         """Overriding str(*) method to make more human-readable."""
```
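Judging from the call sites above, the new `add_instance_rate_stats_to_benchmark_metrics` helper returns an `(error_message, benchmark_metrics)` pair, so a single Pricing API failure does not discard the metrics that were already present, and the loop stops retrying once an authorization error is seen, logging one warning instead of one per config. A hypothetical sketch of that contract; the body is an assumption for illustration, not the library's implementation (the imports and the `JumpStartBenchmarkStat`/`get_instance_rate_per_hour` usage mirror the removed code above):

```python
from typing import Dict, List, Optional, Tuple

from botocore.exceptions import ClientError

from sagemaker.jumpstart.types import JumpStartBenchmarkStat
from sagemaker.utils import get_instance_rate_per_hour


def add_instance_rate_stats_to_benchmark_metrics(
    region: str,
    metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]],
) -> Tuple[Optional[str], Optional[Dict[str, List[JumpStartBenchmarkStat]]]]:
    """Hypothetical contract: append an instance-rate stat per instance type."""
    if metrics is None:
        return None, None
    try:
        for instance_type, stats in metrics.items():
            rate = get_instance_rate_per_hour(instance_type=instance_type, region=region)
            if rate is not None:
                stats.append(JumpStartBenchmarkStat(rate))
        return None, metrics
    except ClientError as err:
        # e.g. "... is not authorized to perform: pricing:GetProducts ..."
        return str(err), metrics
```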