Skip to content

Commit 6f24a4f

Browse files
makungaj1 (Jonathan Makunga) and Jonathan Makunga authored
Deployment Configs - Follow-ups (#4626)
* Init Deployment configs outside Model init.
* Testing with NB
* Testing with NB-V2
* Refactoring, NB testing
* NB Testing and Refactoring
* Testing
* Refactoring
* Testing with NB
* Debug
* Debug display API
* Debug with NB
* Testing with NB
* Refactoring
* Refactoring
* Refactoring and NB testing
* Testing with NB
* Refactoring
* Prefix instance type with ml
* Fix unit tests

---------

Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 4891361 commit 6f24a4f

File tree

11 files changed

+582
-307
lines changed

11 files changed

+582
-307
lines changed

src/sagemaker/jumpstart/model.py

+85-75
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,17 @@
4141
from sagemaker.jumpstart.types import (
4242
JumpStartSerializablePayload,
4343
DeploymentConfigMetadata,
44-
JumpStartBenchmarkStat,
45-
JumpStartMetadataConfig,
4644
)
4745
from sagemaker.jumpstart.utils import (
4846
validate_model_id_and_get_type,
4947
verify_model_region_and_return_specs,
5048
get_jumpstart_configs,
5149
get_metrics_from_deployment_configs,
50+
add_instance_rate_stats_to_benchmark_metrics,
5251
)
5352
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
5453
from sagemaker.jumpstart.enums import JumpStartModelType
55-
from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour
54+
from sagemaker.utils import stringify_object, format_tags, Tags
5655
from sagemaker.model import (
5756
Model,
5857
ModelPackage,
@@ -361,17 +360,13 @@ def _validate_model_id_and_type():
361360
self.model_package_arn = model_init_kwargs.model_package_arn
362361
self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)
363362

364-
metadata_configs = get_jumpstart_configs(
363+
self._metadata_configs = get_jumpstart_configs(
365364
region=self.region,
366365
model_id=self.model_id,
367366
model_version=self.model_version,
368367
sagemaker_session=self.sagemaker_session,
369368
model_type=self.model_type,
370369
)
371-
self._deployment_configs = [
372-
self._convert_to_deployment_config_metadata(config_name, config)
373-
for config_name, config in metadata_configs.items()
374-
]
375370

376371
def log_subscription_warning(self) -> None:
377372
"""Log message prompting the customer to subscribe to the proprietary model."""
@@ -449,33 +444,46 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
449444

450445
@property
451446
def deployment_config(self) -> Optional[Dict[str, Any]]:
452-
"""The deployment config that will be applied to the model.
447+
"""The deployment config that will be applied to ``This`` model.
453448
454449
Returns:
455-
Optional[Dict[str, Any]]: Deployment config that will be applied to the model.
450+
Optional[Dict[str, Any]]: Deployment config.
456451
"""
457-
return self._retrieve_selected_deployment_config(self.config_name)
452+
deployment_config = self._retrieve_selected_deployment_config(
453+
self.config_name, self.instance_type
454+
)
455+
return deployment_config.to_json() if deployment_config is not None else None
458456

459457
@property
460458
def benchmark_metrics(self) -> pd.DataFrame:
461-
"""Benchmark Metrics for deployment configs
459+
"""Benchmark Metrics for deployment configs.
462460
463461
Returns:
464-
Metrics: Pandas DataFrame object.
462+
Benchmark Metrics: Pandas DataFrame object.
465463
"""
466-
return pd.DataFrame(self._get_benchmarks_data(self.config_name))
464+
benchmark_metrics_data = self._get_deployment_configs_benchmarks_data(
465+
self.config_name, self.instance_type
466+
)
467+
keys = list(benchmark_metrics_data.keys())
468+
df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[0], keys[1]])
469+
return df
467470

468471
def display_benchmark_metrics(self) -> None:
469-
"""Display Benchmark Metrics for deployment configs."""
470-
print(self.benchmark_metrics.to_markdown())
472+
"""Display deployment configs benchmark metrics."""
473+
print(self.benchmark_metrics.to_markdown(index=False))
471474

472475
def list_deployment_configs(self) -> List[Dict[str, Any]]:
473476
"""List deployment configs for ``This`` model.
474477
475478
Returns:
476479
List[Dict[str, Any]]: A list of deployment configs.
477480
"""
478-
return self._deployment_configs
481+
return [
482+
deployment_config.to_json()
483+
for deployment_config in self._get_deployment_configs(
484+
self.config_name, self.instance_type
485+
)
486+
]
479487

480488
def _create_sagemaker_model(
481489
self,
@@ -866,92 +874,94 @@ def register_deploy_wrapper(*args, **kwargs):
866874
return model_package
867875

868876
@lru_cache
869-
def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
877+
def _get_deployment_configs_benchmarks_data(
878+
self, config_name: str, instance_type: str
879+
) -> Dict[str, Any]:
870880
"""Deployment configs benchmark metrics.
871881
872882
Args:
873-
config_name (str): The name of the selected deployment config.
883+
config_name (str): Name of selected deployment config.
884+
instance_type (str): The selected Instance type.
874885
Returns:
875886
Dict[str, List[str]]: Deployment config benchmark data.
876887
"""
877888
return get_metrics_from_deployment_configs(
878-
self._deployment_configs,
879-
config_name,
889+
self._get_deployment_configs(config_name, instance_type)
880890
)
881891

882892
@lru_cache
883-
def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]:
884-
"""Retrieve the deployment config to apply to the model.
893+
def _retrieve_selected_deployment_config(
894+
self, config_name: str, instance_type: str
895+
) -> Optional[DeploymentConfigMetadata]:
896+
"""Retrieve the deployment config to apply to `This` model.
885897
886898
Args:
887899
config_name (str): The name of the deployment config to retrieve.
900+
instance_type (str): The instance type of the deployment config to retrieve.
888901
Returns:
889902
Optional[Dict[str, Any]]: The retrieved deployment config.
890903
"""
891904
if config_name is None:
892905
return None
893906

894-
for deployment_config in self._deployment_configs:
895-
if deployment_config.get("DeploymentConfigName") == config_name:
907+
for deployment_config in self._get_deployment_configs(config_name, instance_type):
908+
if deployment_config.deployment_config_name == config_name:
896909
return deployment_config
897910
return None
898911

899-
def _convert_to_deployment_config_metadata(
900-
self, config_name: str, metadata_config: JumpStartMetadataConfig
901-
) -> Dict[str, Any]:
902-
"""Retrieve deployment config for config name.
912+
@lru_cache
913+
def _get_deployment_configs(
914+
self, selected_config_name: str, selected_instance_type: str
915+
) -> List[DeploymentConfigMetadata]:
916+
"""Retrieve deployment configs metadata.
903917
904918
Args:
905-
config_name (str): Name of deployment config.
906-
metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.
907-
Returns:
908-
A deployment metadata config for config name (dict[str, Any]).
919+
selected_config_name (str): The name of the selected deployment config.
920+
selected_instance_type (str): The selected instance type.
909921
"""
910-
default_inference_instance_type = metadata_config.resolved_config.get(
911-
"default_inference_instance_type"
912-
)
913-
914-
benchmark_metrics = (
915-
metadata_config.benchmark_metrics.get(default_inference_instance_type)
916-
if metadata_config.benchmark_metrics is not None
917-
else None
918-
)
919-
920-
should_fetch_instance_rate_metric = True
921-
if benchmark_metrics is not None:
922-
for benchmark_metric in benchmark_metrics:
923-
if benchmark_metric.name.lower() == "instance rate":
924-
should_fetch_instance_rate_metric = False
925-
break
926-
927-
if should_fetch_instance_rate_metric:
928-
instance_rate = get_instance_rate_per_hour(
929-
instance_type=default_inference_instance_type, region=self.region
922+
deployment_configs = []
923+
if self._metadata_configs is None:
924+
return deployment_configs
925+
926+
err = None
927+
for config_name, metadata_config in self._metadata_configs.items():
928+
if err is None or "is not authorized to perform: pricing:GetProducts" not in err:
929+
err, metadata_config.benchmark_metrics = (
930+
add_instance_rate_stats_to_benchmark_metrics(
931+
self.region, metadata_config.benchmark_metrics
932+
)
933+
)
934+
935+
resolved_config = metadata_config.resolved_config
936+
if selected_config_name == config_name:
937+
instance_type_to_use = selected_instance_type
938+
else:
939+
instance_type_to_use = resolved_config.get("default_inference_instance_type")
940+
941+
init_kwargs = get_init_kwargs(
942+
model_id=self.model_id,
943+
instance_type=instance_type_to_use,
944+
sagemaker_session=self.sagemaker_session,
930945
)
931-
if instance_rate is not None:
932-
instance_rate_metric = JumpStartBenchmarkStat(instance_rate)
933-
934-
if benchmark_metrics is None:
935-
benchmark_metrics = [instance_rate_metric]
936-
else:
937-
benchmark_metrics.append(instance_rate_metric)
938-
939-
init_kwargs = get_init_kwargs(
940-
model_id=self.model_id,
941-
instance_type=default_inference_instance_type,
942-
sagemaker_session=self.sagemaker_session,
943-
)
944-
deploy_kwargs = get_deploy_kwargs(
945-
model_id=self.model_id,
946-
instance_type=default_inference_instance_type,
947-
sagemaker_session=self.sagemaker_session,
948-
)
946+
deploy_kwargs = get_deploy_kwargs(
947+
model_id=self.model_id,
948+
instance_type=instance_type_to_use,
949+
sagemaker_session=self.sagemaker_session,
950+
)
951+
deployment_config_metadata = DeploymentConfigMetadata(
952+
config_name,
953+
metadata_config.benchmark_metrics,
954+
resolved_config,
955+
init_kwargs,
956+
deploy_kwargs,
957+
)
958+
deployment_configs.append(deployment_config_metadata)
949959

950-
deployment_config_metadata = DeploymentConfigMetadata(
951-
config_name, benchmark_metrics, init_kwargs, deploy_kwargs
952-
)
960+
if err is not None and "is not authorized to perform: pricing:GetProducts" in err:
961+
error_message = "Instance rate metrics will be omitted. Reason: %s"
962+
JUMPSTART_LOGGER.warning(error_message, err)
953963

954-
return deployment_config_metadata.to_json()
964+
return deployment_configs
955965

956966
def __str__(self) -> str:
957967
"""Overriding str(*) method to make more human-readable."""

src/sagemaker/jumpstart/types.py

+45-27
Original file line numberDiff line numberDiff line change
@@ -2235,29 +2235,37 @@ def to_json(self) -> Dict[str, Any]:
22352235
if hasattr(self, att):
22362236
cur_val = getattr(self, att)
22372237
att = self._convert_to_pascal_case(att)
2238-
if issubclass(type(cur_val), JumpStartDataHolderType):
2239-
json_obj[att] = cur_val.to_json()
2240-
elif isinstance(cur_val, list):
2241-
json_obj[att] = []
2242-
for obj in cur_val:
2243-
if issubclass(type(obj), JumpStartDataHolderType):
2244-
json_obj[att].append(obj.to_json())
2245-
else:
2246-
json_obj[att].append(obj)
2247-
elif isinstance(cur_val, dict):
2248-
json_obj[att] = {}
2249-
for key, val in cur_val.items():
2250-
if issubclass(type(val), JumpStartDataHolderType):
2251-
json_obj[att][self._convert_to_pascal_case(key)] = val.to_json()
2252-
else:
2253-
json_obj[att][key] = val
2254-
else:
2255-
json_obj[att] = cur_val
2238+
json_obj[att] = self._val_to_json(cur_val)
22562239
return json_obj
22572240

2241+
def _val_to_json(self, val: Any) -> Any:
2242+
"""Converts the given value to JSON.
2243+
2244+
Args:
2245+
val (Any): The value to convert.
2246+
Returns:
2247+
Any: The converted json value.
2248+
"""
2249+
if issubclass(type(val), JumpStartDataHolderType):
2250+
return val.to_json()
2251+
if isinstance(val, list):
2252+
list_obj = []
2253+
for obj in val:
2254+
list_obj.append(self._val_to_json(obj))
2255+
return list_obj
2256+
if isinstance(val, dict):
2257+
dict_obj = {}
2258+
for k, v in val.items():
2259+
if isinstance(v, JumpStartDataHolderType):
2260+
dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v)
2261+
else:
2262+
dict_obj[k] = self._val_to_json(v)
2263+
return dict_obj
2264+
return val
2265+
22582266

22592267
class DeploymentArgs(BaseDeploymentConfigDataHolder):
2260-
"""Dataclass representing a Deployment Config."""
2268+
"""Dataclass representing a Deployment Args."""
22612269

22622270
__slots__ = [
22632271
"image_uri",
@@ -2270,9 +2278,12 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder):
22702278
]
22712279

22722280
def __init__(
2273-
self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs
2281+
self,
2282+
init_kwargs: Optional[JumpStartModelInitKwargs] = None,
2283+
deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
2284+
resolved_config: Optional[Dict[str, Any]] = None,
22742285
):
2275-
"""Instantiates DeploymentConfig object."""
2286+
"""Instantiates DeploymentArgs object."""
22762287
if init_kwargs is not None:
22772288
self.image_uri = init_kwargs.image_uri
22782289
self.model_data = init_kwargs.model_data
@@ -2287,6 +2298,11 @@ def __init__(
22872298
self.container_startup_health_check_timeout = (
22882299
deploy_kwargs.container_startup_health_check_timeout
22892300
)
2301+
if resolved_config is not None:
2302+
self.default_instance_type = resolved_config.get("default_inference_instance_type")
2303+
self.supported_instance_types = resolved_config.get(
2304+
"supported_inference_instance_types"
2305+
)
22902306

22912307

22922308
class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
@@ -2301,13 +2317,15 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
23012317

23022318
def __init__(
23032319
self,
2304-
config_name: str,
2305-
benchmark_metrics: List[JumpStartBenchmarkStat],
2306-
init_kwargs: JumpStartModelInitKwargs,
2307-
deploy_kwargs: JumpStartModelDeployKwargs,
2320+
config_name: Optional[str] = None,
2321+
benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = None,
2322+
resolved_config: Optional[Dict[str, Any]] = None,
2323+
init_kwargs: Optional[JumpStartModelInitKwargs] = None,
2324+
deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
23082325
):
23092326
"""Instantiates DeploymentConfigMetadata object."""
23102327
self.deployment_config_name = config_name
2311-
self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs)
2312-
self.acceleration_configs = None
2328+
self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs, resolved_config)
23132329
self.benchmark_metrics = benchmark_metrics
2330+
if resolved_config is not None:
2331+
self.acceleration_configs = resolved_config.get("acceleration_configs")

0 commit comments

Comments
 (0)