Skip to content

Commit 1dabf41

Browse files
makungaj1Jonathan Makunga
authored and committed
Benchmark feature v2 (aws#4618)
* Add functionalities to get and set deployment config * Resolve PR comments * ModelBuilder-JS * Add Unit tests * Refactoring * Testing with Notebook * Test backward compatibility * Remove Accelerated column if all not enabled * Fix docstring * Resolved PR Review comments * Docstring * increase code coverage * Testing fix with Notebook * Only fetch instance rate metrics if not present * Increase code coverage --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 006e577 commit 1dabf41

File tree

5 files changed

+44
-14
lines changed

5 files changed

+44
-14
lines changed

src/sagemaker/jumpstart/model.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
validate_model_id_and_get_type,
4949
verify_model_region_and_return_specs,
5050
get_jumpstart_configs,
51-
extract_metrics_from_deployment_configs,
51+
get_metrics_from_deployment_configs,
5252
)
5353
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
5454
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -868,7 +868,7 @@ def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
868868
Returns:
869869
Dict[str, List[str]]: Deployment config benchmark data.
870870
"""
871-
return extract_metrics_from_deployment_configs(
871+
return get_metrics_from_deployment_configs(
872872
self._deployment_configs,
873873
config_name,
874874
)
@@ -905,20 +905,29 @@ def _convert_to_deployment_config_metadata(
905905
"default_inference_instance_type"
906906
)
907907

908-
instance_rate = get_instance_rate_per_hour(
909-
instance_type=default_inference_instance_type, region=self.region
910-
)
911-
912908
benchmark_metrics = (
913909
metadata_config.benchmark_metrics.get(default_inference_instance_type)
914910
if metadata_config.benchmark_metrics is not None
915911
else None
916912
)
917-
if instance_rate is not None:
918-
if benchmark_metrics is not None:
919-
benchmark_metrics.append(JumpStartBenchmarkStat(instance_rate))
913+
914+
should_fetch_instance_rate_metric = True
915+
if benchmark_metrics is not None:
916+
for benchmark_metric in benchmark_metrics:
917+
if benchmark_metric.name.lower() == "instance rate":
918+
should_fetch_instance_rate_metric = False
919+
break
920+
921+
if should_fetch_instance_rate_metric:
922+
instance_rate = get_instance_rate_per_hour(
923+
instance_type=default_inference_instance_type, region=self.region
924+
)
925+
instance_rate_metric = JumpStartBenchmarkStat(instance_rate)
926+
927+
if benchmark_metrics is None:
928+
benchmark_metrics = [instance_rate_metric]
920929
else:
921-
benchmark_metrics = [JumpStartBenchmarkStat(instance_rate)]
930+
benchmark_metrics.append(instance_rate_metric)
922931

923932
init_kwargs = get_init_kwargs(
924933
model_id=self.model_id,

src/sagemaker/jumpstart/utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ def get_jumpstart_configs(
10301030
)
10311031

10321032

1033-
def extract_metrics_from_deployment_configs(
1033+
def get_metrics_from_deployment_configs(
10341034
deployment_configs: List[Dict[str, Any]], config_name: str
10351035
) -> Dict[str, List[str]]:
10361036
"""Extracts metrics from deployment configs.

tests/unit/sagemaker/jumpstart/model/test_model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
get_base_spec_with_prototype_configs,
5252
get_mock_init_kwargs,
5353
get_base_deployment_configs,
54+
get_base_spec_with_prototype_configs_with_missing_benchmarks,
5455
)
5556
import boto3
5657

@@ -1790,7 +1791,7 @@ def test_model_retrieve_deployment_config(
17901791

17911792
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
17921793
mock_verify_model_region_and_return_specs.side_effect = (
1793-
lambda *args, **kwargs: get_base_spec_with_prototype_configs()
1794+
lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks()
17941795
)
17951796
mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: {
17961797
"name": "Instance Rate",
@@ -1838,7 +1839,7 @@ def test_model_display_benchmark_metrics(
18381839

18391840
mock_get_init_kwargs.side_effect = lambda *args, **kwargs: get_mock_init_kwargs(model_id)
18401841
mock_verify_model_region_and_return_specs.side_effect = (
1841-
lambda *args, **kwargs: get_base_spec_with_prototype_configs()
1842+
lambda *args, **kwargs: get_base_spec_with_prototype_configs_with_missing_benchmarks()
18421843
)
18431844
mock_get_instance_rate_per_hour.side_effect = lambda *args, **kwargs: {
18441845
"name": "Instance Rate",

tests/unit/sagemaker/jumpstart/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1810,6 +1810,6 @@ def test_get_jumpstart_benchmark_stats_training(
18101810
],
18111811
)
18121812
def test_extract_metrics_from_deployment_configs(config_name, configs, expected):
1813-
data = utils.extract_metrics_from_deployment_configs(configs, config_name)
1813+
data = utils.get_metrics_from_deployment_configs(configs, config_name)
18141814

18151815
assert data == expected

tests/unit/sagemaker/jumpstart/utils.py

+20
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,26 @@ def get_base_spec_with_prototype_configs(
226226
return JumpStartModelSpecs(spec)
227227

228228

229+
def get_base_spec_with_prototype_configs_with_missing_benchmarks(
    region: str = None,
    model_id: str = None,
    version: str = None,
    s3_client: boto3.client = None,
    model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> JumpStartModelSpecs:
    """Build a test spec whose "neuron-inference" config has no benchmark metrics.

    The spec is assembled from deep copies of ``BASE_SPEC`` and
    ``INFERENCE_CONFIGS`` so the shared module-level fixtures are not mutated.
    The ``benchmark_metrics`` entry of the "neuron-inference" inference config
    is set to ``None`` before the configs are merged into the spec.

    Args:
        region: Not used in the body of this function.
        model_id: Not used in the body of this function.
        version: Not used in the body of this function.
        s3_client: Not used in the body of this function.
        model_type: Not used in the body of this function.

    Returns:
        JumpStartModelSpecs: Specs built from the base spec plus inference and
        training configs (with their rankings), where the "neuron-inference"
        config has ``benchmark_metrics`` set to ``None``.
    """
    spec = copy.deepcopy(BASE_SPEC)

    # Mutate a private copy only; INFERENCE_CONFIGS is shared by other fixtures.
    inference_configs_copy = copy.deepcopy(INFERENCE_CONFIGS)
    inference_configs_copy["inference_configs"]["neuron-inference"]["benchmark_metrics"] = None

    # Merge the *modified* copy. Merging the pristine INFERENCE_CONFIGS here
    # would silently discard the "missing benchmarks" change made above.
    inference_configs = {**inference_configs_copy, **INFERENCE_CONFIG_RANKINGS}
    training_configs = {**TRAINING_CONFIGS, **TRAINING_CONFIG_RANKINGS}

    spec.update(inference_configs)
    spec.update(training_configs)

    return JumpStartModelSpecs(spec)
247+
248+
229249
def get_prototype_spec_with_configs(
230250
region: str = None,
231251
model_id: str = None,

0 commit comments

Comments
 (0)