Skip to content

Commit 9cc4fb0

Browse files
makungaj1 (Jonathan Makunga)
authored and committed
Add ReadOnly APIs (aws#4606)
* Add ReadOnly APIs * Resolving PR review comments * Resolve PR review comments * Refactoring * Refactoring * Add Caching * Refactor * Resolving conflicts * Add Unit Tests * Fix Unit Tests * Fix unit tests * Fix UT * Refactoring * Fix Integ tests * Refactoring after Notebook testing * Fix code styles --------- Co-authored-by: Jonathan Makunga <[email protected]>
1 parent a73e8a5 commit 9cc4fb0

File tree

14 files changed

+1185
-5
lines changed

14 files changed

+1185
-5
lines changed

src/sagemaker/jumpstart/model.py

+106-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515
from __future__ import absolute_import
1616

17-
from typing import Dict, List, Optional, Union
17+
from functools import lru_cache
18+
from typing import Dict, List, Optional, Union, Any
19+
import pandas as pd
1820
from botocore.exceptions import ClientError
1921

2022
from sagemaker import payloads
@@ -36,14 +38,21 @@
3638
get_init_kwargs,
3739
get_register_kwargs,
3840
)
39-
from sagemaker.jumpstart.types import JumpStartSerializablePayload
41+
from sagemaker.jumpstart.types import (
42+
JumpStartSerializablePayload,
43+
DeploymentConfigMetadata,
44+
JumpStartBenchmarkStat,
45+
JumpStartMetadataConfig,
46+
)
4047
from sagemaker.jumpstart.utils import (
4148
validate_model_id_and_get_type,
4249
verify_model_region_and_return_specs,
50+
get_jumpstart_configs,
51+
extract_metrics_from_deployment_configs,
4352
)
4453
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
4554
from sagemaker.jumpstart.enums import JumpStartModelType
46-
from sagemaker.utils import stringify_object, format_tags, Tags
55+
from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour
4756
from sagemaker.model import (
4857
Model,
4958
ModelPackage,
@@ -352,6 +361,18 @@ def _validate_model_id_and_type():
352361
self.model_package_arn = model_init_kwargs.model_package_arn
353362
self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)
354363

364+
metadata_configs = get_jumpstart_configs(
365+
region=self.region,
366+
model_id=self.model_id,
367+
model_version=self.model_version,
368+
sagemaker_session=self.sagemaker_session,
369+
model_type=self.model_type,
370+
)
371+
self._deployment_configs = [
372+
self._convert_to_deployment_config_metadata(config_name, config)
373+
for config_name, config in metadata_configs.items()
374+
]
375+
355376
def log_subscription_warning(self) -> None:
356377
"""Log message prompting the customer to subscribe to the proprietary model."""
357378
subscription_link = verify_model_region_and_return_specs(
@@ -420,6 +441,27 @@ def set_deployment_config(self, config_name: Optional[str]) -> None:
420441
model_id=self.model_id, model_version=self.model_version, config_name=config_name
421442
)
422443

444+
@property
def benchmark_metrics(self) -> pd.DataFrame:
    """Benchmark metrics for this model's deployment configs.

    Returns:
        pd.DataFrame: Table built from the benchmark data returned by
            ``_get_benchmark_data`` for the model's current ``config_name``.
    """
    benchmark_data = self._get_benchmark_data(self.config_name)
    return pd.DataFrame(benchmark_data)
452+
453+
def display_benchmark_metrics(self) -> None:
    """Print the deployment config benchmark metrics as a Markdown table."""
    markdown_table = self.benchmark_metrics.to_markdown()
    print(markdown_table)
456+
457+
def list_deployment_configs(self) -> List[Dict[str, Any]]:
    """Return the deployment configs held by this model.

    Returns:
        List[Dict[str, Any]]: The list stored in ``_deployment_configs``,
            returned as-is (not copied).
    """
    deployment_configs = self._deployment_configs
    return deployment_configs
464+
423465
def _create_sagemaker_model(
424466
self,
425467
instance_type=None,
@@ -808,6 +850,67 @@ def register_deploy_wrapper(*args, **kwargs):
808850

809851
return model_package
810852

853+
def _get_benchmark_data(self, config_name: str) -> Dict[str, List[str]]:
    """Constructs deployment configs benchmark data.

    The result is memoized per instance and per ``config_name``, so repeated
    calls with the same name return the same dictionary object.

    Args:
        config_name (str): The name of the selected deployment config.
    Returns:
        Dict[str, List[str]]: Deployment config benchmark data.
    """
    # A per-instance cache replaces ``functools.lru_cache`` on the method:
    # the decorator keeps one cache on the class, keyed on ``self``, which
    # holds a strong reference to every instance for the cache's lifetime.
    cache = getattr(self, "_benchmark_data_cache", None)
    if cache is None:
        cache = {}
        self._benchmark_data_cache = cache

    if config_name not in cache:
        cache[config_name] = extract_metrics_from_deployment_configs(
            self._deployment_configs,
            config_name,
        )
    return cache[config_name]
866+
867+
def _convert_to_deployment_config_metadata(
    self, config_name: str, metadata_config: JumpStartMetadataConfig
) -> Dict[str, Any]:
    """Build the JSON form of the deployment config metadata for a config name.

    The benchmark metrics of the config's default inference instance type are
    combined with the instance's hourly rate (when one is available) and with
    the init/deploy kwargs resolved for that instance type.

    Args:
        config_name (str): Name of deployment config.
        metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.
    Returns:
        A deployment metadata config for config name (dict[str, Any]).
    """
    default_inference_instance_type = metadata_config.resolved_config.get(
        "default_inference_instance_type"
    )

    instance_rate = get_instance_rate_per_hour(
        instance_type=default_inference_instance_type, region=self.region
    )

    benchmark_metrics = (
        metadata_config.benchmark_metrics.get(default_inference_instance_type)
        if metadata_config.benchmark_metrics is not None
        else None
    )
    if instance_rate is not None:
        # Build a new list rather than appending in place: the original list
        # belongs to ``metadata_config``, and mutating it would add one more
        # rate entry to that object every time this method runs on it.
        benchmark_metrics = [
            *(benchmark_metrics or []),
            JumpStartBenchmarkStat(instance_rate),
        ]

    # NOTE(review): only model_id, instance_type and sagemaker_session are
    # forwarded below; confirm the helpers' defaults for version/region match
    # this model.
    init_kwargs = get_init_kwargs(
        model_id=self.model_id,
        instance_type=default_inference_instance_type,
        sagemaker_session=self.sagemaker_session,
    )
    deploy_kwargs = get_deploy_kwargs(
        model_id=self.model_id,
        instance_type=default_inference_instance_type,
        sagemaker_session=self.sagemaker_session,
    )

    deployment_config_metadata = DeploymentConfigMetadata(
        config_name, benchmark_metrics, init_kwargs, deploy_kwargs
    )

    return deployment_config_metadata.to_json()
913+
811914
def __str__(self) -> str:
812915
"""Overriding str(*) method to make more human-readable."""
813916
return stringify_object(self)

src/sagemaker/jumpstart/types.py

+96
Original file line numberDiff line numberDiff line change
@@ -2208,3 +2208,99 @@ def __init__(
22082208
self.skip_model_validation = skip_model_validation
22092209
self.source_uri = source_uri
22102210
self.config_name = config_name
2211+
2212+
2213+
class BaseDeploymentConfigDataHolder(JumpStartDataHolderType):
    """Base class for Deployment Config Data."""

    def _convert_to_pascal_case(self, attr_name: str) -> str:
        """Converts a snake_case attribute name into a PascalCase string.

        Args:
            attr_name (str): The snake_case attribute name.
        Returns:
            str: The PascalCased attribute name.
        """
        return attr_name.replace("_", " ").title().replace(" ", "")

    def to_json(self) -> Dict[str, Any]:
        """Represents ``This`` object as JSON.

        Every slot that has been assigned is emitted under its PascalCase name.
        Data-holder values are serialized through their own ``to_json``,
        directly or as items of a list or values of a dict. Dict keys are
        PascalCased only for entries whose value is a data holder; all other
        keys and values are copied unchanged.

        Returns:
            Dict[str, Any]: The JSON-style dictionary for this object.
        """
        json_obj = {}
        for slot_name in self.__slots__:
            # Slots that were never assigned are left out of the output.
            if not hasattr(self, slot_name):
                continue

            cur_val = getattr(self, slot_name)
            json_key = self._convert_to_pascal_case(slot_name)

            if isinstance(cur_val, JumpStartDataHolderType):
                json_obj[json_key] = cur_val.to_json()
            elif isinstance(cur_val, list):
                json_obj[json_key] = [
                    item.to_json() if isinstance(item, JumpStartDataHolderType) else item
                    for item in cur_val
                ]
            elif isinstance(cur_val, dict):
                converted = {}
                for key, val in cur_val.items():
                    if isinstance(val, JumpStartDataHolderType):
                        converted[self._convert_to_pascal_case(key)] = val.to_json()
                    else:
                        converted[key] = val
                json_obj[json_key] = converted
            else:
                json_obj[json_key] = cur_val
        return json_obj
2252+
2253+
2254+
class DeploymentConfig(BaseDeploymentConfigDataHolder):
    """Data holder for the settings of a single deployment config."""

    __slots__ = [
        "model_data_download_timeout",
        "container_startup_health_check_timeout",
        "image_uri",
        "model_data",
        "instance_type",
        "environment",
        "compute_resource_requirements",
    ]

    def __init__(
        self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs
    ):
        """Populate the slots from the given init and deploy kwargs.

        A slot is only assigned when its source object is not ``None``;
        ``compute_resource_requirements`` additionally requires
        ``init_kwargs.resources`` to be set.
        """
        if init_kwargs is not None:
            # (slot name, attribute on init_kwargs) pairs, copied in order.
            copied_attrs = (
                ("image_uri", "image_uri"),
                ("model_data", "model_data"),
                ("instance_type", "instance_type"),
                ("environment", "env"),
            )
            for slot_name, source_name in copied_attrs:
                setattr(self, slot_name, getattr(init_kwargs, source_name))

            if init_kwargs.resources is not None:
                requirements = init_kwargs.resources.get_compute_resource_requirements()
                self.compute_resource_requirements = requirements

        if deploy_kwargs is not None:
            self.model_data_download_timeout = deploy_kwargs.model_data_download_timeout
            startup_timeout = deploy_kwargs.container_startup_health_check_timeout
            self.container_startup_health_check_timeout = startup_timeout
2285+
2286+
2287+
class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
    """Data holder pairing a config name with its benchmarks and deployment config."""

    __slots__ = [
        "config_name",
        "benchmark_metrics",
        "deployment_config",
    ]

    def __init__(
        self,
        config_name: str,
        benchmark_metrics: List[JumpStartBenchmarkStat],
        init_kwargs: JumpStartModelInitKwargs,
        deploy_kwargs: JumpStartModelDeployKwargs,
    ):
        """Store the name and metrics and wrap the kwargs in a ``DeploymentConfig``."""
        self.config_name = config_name
        self.benchmark_metrics = benchmark_metrics
        deployment_config = DeploymentConfig(init_kwargs, deploy_kwargs)
        self.deployment_config = deployment_config

src/sagemaker/jumpstart/utils.py

+41
Original file line numberDiff line numberDiff line change
@@ -999,3 +999,44 @@ def get_jumpstart_configs(
999999
if metadata_configs
10001000
else {}
10011001
)
1002+
1003+
1004+
def extract_metrics_from_deployment_configs(
    deployment_configs: List[Dict[str, Any]], config_name: str
) -> Dict[str, List[str]]:
    """Extracts metrics from deployment configs.

    Configs without a ``DeploymentConfig`` or ``BenchmarkMetrics`` entry are
    skipped. The metric columns are defined by the first config that is not
    skipped. Every column holds exactly one entry per listed config, with
    ``None`` where a config does not report that metric, so the result can be
    loaded into a table.

    Args:
        deployment_configs (list[dict[str, Any]]): List of deployment configs.
        config_name (str): The name of the deployment config use by the model.
    Returns:
        Dict[str, List[str]]: Column name mapped to that column's values, one
            value per listed deployment config.
    """
    data = {"Config Name": [], "Instance Type": [], "Selected": []}
    # Set once the first non-skipped config is seen. Keying this on the list
    # position instead would lose every metric column when the first entry is
    # skipped.
    metric_columns = None

    for deployment_config in deployment_configs:
        deployment = deployment_config.get("DeploymentConfig")
        benchmark_metrics = deployment_config.get("BenchmarkMetrics")
        if deployment is None or benchmark_metrics is None:
            continue

        current_name = deployment_config.get("ConfigName")
        data["Config Name"].append(current_name)
        data["Instance Type"].append(deployment.get("InstanceType"))
        data["Selected"].append(
            "Yes" if (config_name is not None and config_name == current_name) else "No"
        )

        # Collect this config's metrics by column; the first value wins if a
        # column name repeats, so a row never adds two entries to one column.
        row = {}
        for benchmark_metric in benchmark_metrics:
            column_name = f"{benchmark_metric.get('name')} ({benchmark_metric.get('unit')})"
            row.setdefault(column_name, benchmark_metric.get("value"))

        if metric_columns is None:
            metric_columns = list(row)
            for column_name in metric_columns:
                data[column_name] = []

        # Pad missing metrics with None so all columns stay the same length.
        for column_name in metric_columns:
            data[column_name].append(row.get(column_name))

    return data

src/sagemaker/serve/builder/jumpstart_builder.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import copy
1717
from abc import ABC, abstractmethod
1818
from datetime import datetime, timedelta
19-
from typing import Type
19+
from typing import Type, Any, List, Dict
2020
import logging
2121

2222
from sagemaker.model import Model
@@ -431,6 +431,18 @@ def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
431431
sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
432432
)
433433

434+
def display_benchmark_metrics(self):
    """Display the deployment config benchmark metrics via the wrapped ``pysdk_model``."""
    model = self.pysdk_model
    model.display_benchmark_metrics()
437+
438+
def list_deployment_configs(self) -> List[Dict[str, Any]]:
    """Return the deployment configs reported by the wrapped ``pysdk_model``.

    Returns:
        List[Dict[str, Any]]: A list of deployment configs.
    """
    deployment_configs = self.pysdk_model.list_deployment_configs()
    return deployment_configs
445+
434446
def _build_for_jumpstart(self):
435447
"""Placeholder docstring"""
436448
# we do not pickle for jumpstart. set to none

0 commit comments

Comments
 (0)