Skip to content

Commit c537f97

Browse files
makungaj1 authored and Jonathan Makunga committed
Deployment Configs - Follow-ups (aws#4626)
* Init Deployment configs outside Model init.
* Testing with NB
* Testing with NB-V2
* Refactoring, NB testing
* NB Testing and Refactoring
* Testing
* Refactoring
* Testing with NB
* Debug
* Debug display API
* Debug with NB
* Testing with NB
* Refactoring
* Refactoring
* Refactoring and NB testing
* Testing with NB
* Refactoring
* Prefix instance type with ml
* Fix unit tests

---------

Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 2040564 commit c537f97

File tree

11 files changed

+582 -307 lines changed

src/sagemaker/jumpstart/model.py

+85 -75
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,17 @@
4141
from sagemaker.jumpstart.types import (
4242
JumpStartSerializablePayload,
4343
DeploymentConfigMetadata,
44-
JumpStartBenchmarkStat,
45-
JumpStartMetadataConfig,
4644
)
4745
from sagemaker.jumpstart.utils import (
4846
validate_model_id_and_get_type,
4947
verify_model_region_and_return_specs,
5048
get_jumpstart_configs,
5149
get_metrics_from_deployment_configs,
50+
add_instance_rate_stats_to_benchmark_metrics,
5251
)
5352
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
5453
from sagemaker.jumpstart.enums import JumpStartModelType
55-
from sagemaker.utils import stringify_object, format_tags, Tags, get_instance_rate_per_hour
54+
from sagemaker.utils import stringify_object, format_tags, Tags
5655
from sagemaker.model import (
5756
Model,
5857
ModelPackage,
@@ -361,17 +360,13 @@ def _validate_model_id_and_type():
361360
self.model_package_arn = model_init_kwargs.model_package_arn
362361
self.init_kwargs = model_init_kwargs.to_kwargs_dict(False)
363362

364-
metadata_configs = get_jumpstart_configs(
363+
self._metadata_configs = get_jumpstart_configs(
365364
region=self.region,
366365
model_id=self.model_id,
367366
model_version=self.model_version,
368367
sagemaker_session=self.sagemaker_session,
369368
model_type=self.model_type,
370369
)
371-
self._deployment_configs = [
372-
self._convert_to_deployment_config_metadata(config_name, config)
373-
for config_name, config in metadata_configs.items()
374-
]
375370

376371
def log_subscription_warning(self) -> None:
377372
"""Log message prompting the customer to subscribe to the proprietary model."""
@@ -449,33 +444,46 @@ def set_deployment_config(self, config_name: str, instance_type: str) -> None:
449444

450445
@property
451446
def deployment_config(self) -> Optional[Dict[str, Any]]:
452-
"""The deployment config that will be applied to the model.
447+
"""The deployment config that will be applied to ``This`` model.
453448
454449
Returns:
455-
Optional[Dict[str, Any]]: Deployment config that will be applied to the model.
450+
Optional[Dict[str, Any]]: Deployment config.
456451
"""
457-
return self._retrieve_selected_deployment_config(self.config_name)
452+
deployment_config = self._retrieve_selected_deployment_config(
453+
self.config_name, self.instance_type
454+
)
455+
return deployment_config.to_json() if deployment_config is not None else None
458456

459457
@property
460458
def benchmark_metrics(self) -> pd.DataFrame:
461-
"""Benchmark Metrics for deployment configs
459+
"""Benchmark Metrics for deployment configs.
462460
463461
Returns:
464-
Metrics: Pandas DataFrame object.
462+
Benchmark Metrics: Pandas DataFrame object.
465463
"""
466-
return pd.DataFrame(self._get_benchmarks_data(self.config_name))
464+
benchmark_metrics_data = self._get_deployment_configs_benchmarks_data(
465+
self.config_name, self.instance_type
466+
)
467+
keys = list(benchmark_metrics_data.keys())
468+
df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[0], keys[1]])
469+
return df
467470

468471
def display_benchmark_metrics(self) -> None:
469-
"""Display Benchmark Metrics for deployment configs."""
470-
print(self.benchmark_metrics.to_markdown())
472+
"""Display deployment configs benchmark metrics."""
473+
print(self.benchmark_metrics.to_markdown(index=False))
471474

472475
def list_deployment_configs(self) -> List[Dict[str, Any]]:
473476
"""List deployment configs for ``This`` model.
474477
475478
Returns:
476479
List[Dict[str, Any]]: A list of deployment configs.
477480
"""
478-
return self._deployment_configs
481+
return [
482+
deployment_config.to_json()
483+
for deployment_config in self._get_deployment_configs(
484+
self.config_name, self.instance_type
485+
)
486+
]
479487

480488
def _create_sagemaker_model(
481489
self,
@@ -866,92 +874,94 @@ def register_deploy_wrapper(*args, **kwargs):
866874
return model_package
867875

868876
@lru_cache
869-
def _get_benchmarks_data(self, config_name: str) -> Dict[str, List[str]]:
877+
def _get_deployment_configs_benchmarks_data(
878+
self, config_name: str, instance_type: str
879+
) -> Dict[str, Any]:
870880
"""Deployment configs benchmark metrics.
871881
872882
Args:
873-
config_name (str): The name of the selected deployment config.
883+
config_name (str): Name of selected deployment config.
884+
instance_type (str): The selected Instance type.
874885
Returns:
875886
Dict[str, List[str]]: Deployment config benchmark data.
876887
"""
877888
return get_metrics_from_deployment_configs(
878-
self._deployment_configs,
879-
config_name,
889+
self._get_deployment_configs(config_name, instance_type)
880890
)
881891

882892
@lru_cache
883-
def _retrieve_selected_deployment_config(self, config_name: str) -> Optional[Dict[str, Any]]:
884-
"""Retrieve the deployment config to apply to the model.
893+
def _retrieve_selected_deployment_config(
894+
self, config_name: str, instance_type: str
895+
) -> Optional[DeploymentConfigMetadata]:
896+
"""Retrieve the deployment config to apply to `This` model.
885897
886898
Args:
887899
config_name (str): The name of the deployment config to retrieve.
900+
instance_type (str): The instance type of the deployment config to retrieve.
888901
Returns:
889902
Optional[Dict[str, Any]]: The retrieved deployment config.
890903
"""
891904
if config_name is None:
892905
return None
893906

894-
for deployment_config in self._deployment_configs:
895-
if deployment_config.get("DeploymentConfigName") == config_name:
907+
for deployment_config in self._get_deployment_configs(config_name, instance_type):
908+
if deployment_config.deployment_config_name == config_name:
896909
return deployment_config
897910
return None
898911

899-
def _convert_to_deployment_config_metadata(
900-
self, config_name: str, metadata_config: JumpStartMetadataConfig
901-
) -> Dict[str, Any]:
902-
"""Retrieve deployment config for config name.
912+
@lru_cache
913+
def _get_deployment_configs(
914+
self, selected_config_name: str, selected_instance_type: str
915+
) -> List[DeploymentConfigMetadata]:
916+
"""Retrieve deployment configs metadata.
903917
904918
Args:
905-
config_name (str): Name of deployment config.
906-
metadata_config (JumpStartMetadataConfig): Metadata config for deployment config.
907-
Returns:
908-
A deployment metadata config for config name (dict[str, Any]).
919+
selected_config_name (str): The name of the selected deployment config.
920+
selected_instance_type (str): The selected instance type.
909921
"""
910-
default_inference_instance_type = metadata_config.resolved_config.get(
911-
"default_inference_instance_type"
912-
)
913-
914-
benchmark_metrics = (
915-
metadata_config.benchmark_metrics.get(default_inference_instance_type)
916-
if metadata_config.benchmark_metrics is not None
917-
else None
918-
)
919-
920-
should_fetch_instance_rate_metric = True
921-
if benchmark_metrics is not None:
922-
for benchmark_metric in benchmark_metrics:
923-
if benchmark_metric.name.lower() == "instance rate":
924-
should_fetch_instance_rate_metric = False
925-
break
926-
927-
if should_fetch_instance_rate_metric:
928-
instance_rate = get_instance_rate_per_hour(
929-
instance_type=default_inference_instance_type, region=self.region
922+
deployment_configs = []
923+
if self._metadata_configs is None:
924+
return deployment_configs
925+
926+
err = None
927+
for config_name, metadata_config in self._metadata_configs.items():
928+
if err is None or "is not authorized to perform: pricing:GetProducts" not in err:
929+
err, metadata_config.benchmark_metrics = (
930+
add_instance_rate_stats_to_benchmark_metrics(
931+
self.region, metadata_config.benchmark_metrics
932+
)
933+
)
934+
935+
resolved_config = metadata_config.resolved_config
936+
if selected_config_name == config_name:
937+
instance_type_to_use = selected_instance_type
938+
else:
939+
instance_type_to_use = resolved_config.get("default_inference_instance_type")
940+
941+
init_kwargs = get_init_kwargs(
942+
model_id=self.model_id,
943+
instance_type=instance_type_to_use,
944+
sagemaker_session=self.sagemaker_session,
930945
)
931-
if instance_rate is not None:
932-
instance_rate_metric = JumpStartBenchmarkStat(instance_rate)
933-
934-
if benchmark_metrics is None:
935-
benchmark_metrics = [instance_rate_metric]
936-
else:
937-
benchmark_metrics.append(instance_rate_metric)
938-
939-
init_kwargs = get_init_kwargs(
940-
model_id=self.model_id,
941-
instance_type=default_inference_instance_type,
942-
sagemaker_session=self.sagemaker_session,
943-
)
944-
deploy_kwargs = get_deploy_kwargs(
945-
model_id=self.model_id,
946-
instance_type=default_inference_instance_type,
947-
sagemaker_session=self.sagemaker_session,
948-
)
946+
deploy_kwargs = get_deploy_kwargs(
947+
model_id=self.model_id,
948+
instance_type=instance_type_to_use,
949+
sagemaker_session=self.sagemaker_session,
950+
)
951+
deployment_config_metadata = DeploymentConfigMetadata(
952+
config_name,
953+
metadata_config.benchmark_metrics,
954+
resolved_config,
955+
init_kwargs,
956+
deploy_kwargs,
957+
)
958+
deployment_configs.append(deployment_config_metadata)
949959

950-
deployment_config_metadata = DeploymentConfigMetadata(
951-
config_name, benchmark_metrics, init_kwargs, deploy_kwargs
952-
)
960+
if err is not None and "is not authorized to perform: pricing:GetProducts" in err:
961+
error_message = "Instance rate metrics will be omitted. Reason: %s"
962+
JUMPSTART_LOGGER.warning(error_message, err)
953963

954-
return deployment_config_metadata.to_json()
964+
return deployment_configs
955965

956966
def __str__(self) -> str:
957967
"""Overriding str(*) method to make more human-readable."""

src/sagemaker/jumpstart/types.py

+45 -27
Original file line numberDiff line numberDiff line change
@@ -2238,29 +2238,37 @@ def to_json(self) -> Dict[str, Any]:
22382238
if hasattr(self, att):
22392239
cur_val = getattr(self, att)
22402240
att = self._convert_to_pascal_case(att)
2241-
if issubclass(type(cur_val), JumpStartDataHolderType):
2242-
json_obj[att] = cur_val.to_json()
2243-
elif isinstance(cur_val, list):
2244-
json_obj[att] = []
2245-
for obj in cur_val:
2246-
if issubclass(type(obj), JumpStartDataHolderType):
2247-
json_obj[att].append(obj.to_json())
2248-
else:
2249-
json_obj[att].append(obj)
2250-
elif isinstance(cur_val, dict):
2251-
json_obj[att] = {}
2252-
for key, val in cur_val.items():
2253-
if issubclass(type(val), JumpStartDataHolderType):
2254-
json_obj[att][self._convert_to_pascal_case(key)] = val.to_json()
2255-
else:
2256-
json_obj[att][key] = val
2257-
else:
2258-
json_obj[att] = cur_val
2241+
json_obj[att] = self._val_to_json(cur_val)
22592242
return json_obj
22602243

2244+
def _val_to_json(self, val: Any) -> Any:
2245+
"""Converts the given value to JSON.
2246+
2247+
Args:
2248+
val (Any): The value to convert.
2249+
Returns:
2250+
Any: The converted json value.
2251+
"""
2252+
if issubclass(type(val), JumpStartDataHolderType):
2253+
return val.to_json()
2254+
if isinstance(val, list):
2255+
list_obj = []
2256+
for obj in val:
2257+
list_obj.append(self._val_to_json(obj))
2258+
return list_obj
2259+
if isinstance(val, dict):
2260+
dict_obj = {}
2261+
for k, v in val.items():
2262+
if isinstance(v, JumpStartDataHolderType):
2263+
dict_obj[self._convert_to_pascal_case(k)] = self._val_to_json(v)
2264+
else:
2265+
dict_obj[k] = self._val_to_json(v)
2266+
return dict_obj
2267+
return val
2268+
22612269

22622270
class DeploymentArgs(BaseDeploymentConfigDataHolder):
2263-
"""Dataclass representing a Deployment Config."""
2271+
"""Dataclass representing a Deployment Args."""
22642272

22652273
__slots__ = [
22662274
"image_uri",
@@ -2273,9 +2281,12 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder):
22732281
]
22742282

22752283
def __init__(
2276-
self, init_kwargs: JumpStartModelInitKwargs, deploy_kwargs: JumpStartModelDeployKwargs
2284+
self,
2285+
init_kwargs: Optional[JumpStartModelInitKwargs] = None,
2286+
deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
2287+
resolved_config: Optional[Dict[str, Any]] = None,
22772288
):
2278-
"""Instantiates DeploymentConfig object."""
2289+
"""Instantiates DeploymentArgs object."""
22792290
if init_kwargs is not None:
22802291
self.image_uri = init_kwargs.image_uri
22812292
self.model_data = init_kwargs.model_data
@@ -2290,6 +2301,11 @@ def __init__(
22902301
self.container_startup_health_check_timeout = (
22912302
deploy_kwargs.container_startup_health_check_timeout
22922303
)
2304+
if resolved_config is not None:
2305+
self.default_instance_type = resolved_config.get("default_inference_instance_type")
2306+
self.supported_instance_types = resolved_config.get(
2307+
"supported_inference_instance_types"
2308+
)
22932309

22942310

22952311
class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
@@ -2304,13 +2320,15 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
23042320

23052321
def __init__(
23062322
self,
2307-
config_name: str,
2308-
benchmark_metrics: List[JumpStartBenchmarkStat],
2309-
init_kwargs: JumpStartModelInitKwargs,
2310-
deploy_kwargs: JumpStartModelDeployKwargs,
2323+
config_name: Optional[str] = None,
2324+
benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = None,
2325+
resolved_config: Optional[Dict[str, Any]] = None,
2326+
init_kwargs: Optional[JumpStartModelInitKwargs] = None,
2327+
deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
23112328
):
23122329
"""Instantiates DeploymentConfigMetadata object."""
23132330
self.deployment_config_name = config_name
2314-
self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs)
2315-
self.acceleration_configs = None
2331+
self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs, resolved_config)
23162332
self.benchmark_metrics = benchmark_metrics
2333+
if resolved_config is not None:
2334+
self.acceleration_configs = resolved_config.get("acceleration_configs")

0 commit comments

Comments
 (0)