|
48 | 48 | validate_model_id_and_get_type,
|
49 | 49 | verify_model_region_and_return_specs,
|
50 | 50 | get_jumpstart_configs,
|
51 |
| - extract_metrics_from_deployment_configs, |
| 51 | + get_metrics_from_deployment_configs, |
52 | 52 | )
|
53 | 53 | from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
|
54 | 54 | from sagemaker.jumpstart.enums import JumpStartModelType
|
@@ -868,7 +868,7 @@ def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
|
868 | 868 | Returns:
|
869 | 869 | Dict[str, List[str]]: Deployment config benchmark data.
|
870 | 870 | """
|
871 |
| - return extract_metrics_from_deployment_configs( |
| 871 | + return get_metrics_from_deployment_configs( |
872 | 872 | self._deployment_configs,
|
873 | 873 | config_name,
|
874 | 874 | )
|
@@ -905,20 +905,29 @@ def _convert_to_deployment_config_metadata(
|
905 | 905 | "default_inference_instance_type"
|
906 | 906 | )
|
907 | 907 |
|
908 |
| - instance_rate = get_instance_rate_per_hour( |
909 |
| - instance_type=default_inference_instance_type, region=self.region |
910 |
| - ) |
911 |
| - |
912 | 908 | benchmark_metrics = (
|
913 | 909 | metadata_config.benchmark_metrics.get(default_inference_instance_type)
|
914 | 910 | if metadata_config.benchmark_metrics is not None
|
915 | 911 | else None
|
916 | 912 | )
|
917 |
| - if instance_rate is not None: |
918 |
| - if benchmark_metrics is not None: |
919 |
| - benchmark_metrics.append(JumpStartBenchmarkStat(instance_rate)) |
| 913 | + |
| 914 | + should_fetch_instance_rate_metric = True |
| 915 | + if benchmark_metrics is not None: |
| 916 | + for benchmark_metric in benchmark_metrics: |
| 917 | + if benchmark_metric.name.lower() == "instance rate": |
| 918 | + should_fetch_instance_rate_metric = False |
| 919 | + break |
| 920 | + |
| 921 | + if should_fetch_instance_rate_metric: |
| 922 | + instance_rate = get_instance_rate_per_hour( |
| 923 | + instance_type=default_inference_instance_type, region=self.region |
| 924 | + ) |
| 925 | + instance_rate_metric = JumpStartBenchmarkStat(instance_rate) |
| 926 | + |
| 927 | + if benchmark_metrics is None: |
| 928 | + benchmark_metrics = [instance_rate_metric] |
920 | 929 | else:
|
921 |
| - benchmark_metrics = [JumpStartBenchmarkStat(instance_rate)] |
| 930 | + benchmark_metrics.append(instance_rate_metric) |
922 | 931 |
|
923 | 932 | init_kwargs = get_init_kwargs(
|
924 | 933 | model_id=self.model_id,
|
|
0 commit comments