Skip to content

Commit be9bf89

Browse files
author
Jonathan Makunga
committed
Add Unit Tests
1 parent 432f928 commit be9bf89

File tree

5 files changed

+303
-19
lines changed

5 files changed

+303
-19
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import absolute_import
1616

17-
from functools import lru_cache
17+
from functools import cached_property
1818
from typing import Dict, List, Optional, Union, Any
1919
import pandas as pd
2020
from botocore.exceptions import ClientError
@@ -293,17 +293,6 @@ def __init__(
293293
ValueError: If the model ID is not recognized by JumpStart.
294294
"""
295295

296-
metadata_configs = get_jumpstart_configs(
297-
region=region,
298-
model_id=model_id,
299-
model_version=model_version,
300-
sagemaker_session=sagemaker_session,
301-
)
302-
self._deployment_configs = self._deployment_configs = [
303-
self._convert_to_deployment_config_metadata(config_name, config)
304-
for config_name, config in metadata_configs.items()
305-
]
306-
307296
def _validate_model_id_and_type():
308297
return validate_model_id_and_get_type(
309298
model_id=model_id,
@@ -372,6 +361,17 @@ def _validate_model_id_and_type():
372361
self.model_package_arn = model_init_kwargs.model_package_arn
373362
self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)
374363

364+
metadata_configs = get_jumpstart_configs(
365+
region=region,
366+
model_id=model_id,
367+
model_version=model_version,
368+
sagemaker_session=sagemaker_session,
369+
)
370+
self._deployment_configs = [
371+
self._convert_to_deployment_config_metadata(config_name, config)
372+
for config_name, config in metadata_configs.items()
373+
]
374+
375375
def log_subscription_warning(self) -> None:
376376
"""Log message prompting the customer to subscribe to the proprietary model."""
377377
subscription_link = verify_model_region_and_return_specs(
@@ -828,8 +828,7 @@ def register_deploy_wrapper(*args, **kwargs):
828828

829829
return model_package
830830

831-
@lru_cache
832-
@property
831+
@cached_property
833832
def benchmark_metrics(self) -> pd.DataFrame:
834833
"""Pandas DataFrame object of Benchmark Metrics for deployment configs"""
835834
data = extract_metrics_from_deployment_configs(
@@ -844,16 +843,17 @@ def display_benchmark_metrics(self, *args, **kwargs):
844843

845844
def list_deployment_configs(self) -> List[Dict[str, Any]]:
846845
"""List deployment configs for ``This`` model in the current region.
846+
847847
Returns:
848-
A list of deployment configs (List[Dict[str, Any]]).
848+
List[Dict[str, Any]]: A list of deployment configs.
849849
"""
850850
return self._deployment_configs
851851

852-
@lru_cache
853852
def _convert_to_deployment_config_metadata(
854853
self, config_name: str, metadata_config: JumpStartMetadataConfig
855854
) -> Dict[str, Any]:
856855
"""Retrieve deployment config for config name.
856+
857857
Args:
858858
config_name (str): Name of deployment config.
859859
metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,6 +1005,7 @@ def extract_metrics_from_deployment_configs(
10051005
deployment_configs: list[dict[str, Any]], config_name: str
10061006
) -> Dict[str, List[str]]:
10071007
"""Extracts metrics from deployment configs.
1008+
10081009
Args:
10091010
deployment_configs (list[dict[str, Any]]): List of deployment configs.
10101011
config_name (str): The name of the deployment config use by the model.

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -431,14 +431,15 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
431431
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
432432
)
433433

434-
def display_benchmark_metrics(self):
434+
def display_benchmark_metrics(self, *args, **kwargs):
435435
"""Display Markdown Benchmark Metrics for deployment configs."""
436-
self.pysdk_model.display_benchmark_metrics()
436+
self.pysdk_model.display_benchmark_metrics(args, kwargs)
437437

438438
def list_deployment_configs(self) -> list[dict[str, Any]]:
439439
"""List deployment configs for ``This`` model in the current region.
440+
440441
Returns:
441-
A list of deployment configs (List[Dict[str, Any]]).
442+
List[Dict[str, Any]]: A list of deployment configs.
442443
"""
443444
return self.pysdk_model.list_deployment_configs()
444445

src/sagemaker/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1666,6 +1666,7 @@ def get_instance_rate_per_hour(
16661666
region: str,
16671667
) -> Union[Dict[str, str], None]:
16681668
"""Gets instance rate per hour for the given instance type.
1669+
16691670
Args:
16701671
instance_type (str): The instance type.
16711672
region (str): The region.

0 commit comments

Comments
 (0)