Skip to content

Commit 2be02e0

Browse files
makungaj1 (Jonathan Makunga) and Jonathan Makunga authored
Update: ReadOnly APIs (aws#4707)
* Model data arn
* Refactoring
* Refactoring
* acceleration_configs
* Refactoring
* UT
* Add Filter
* UT
* Revert "UT"
* UT
* UT

---------

Co-authored-by: Jonathan Makunga <[email protected]>
1 parent 06d9eb2 commit 2be02e0

File tree

6 files changed

+26
-17
lines changed

6 files changed

+26
-17
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -469,9 +469,15 @@ def benchmark_metrics(self) -> pd.DataFrame:
469469
df.index = blank_index
470470
return df
471471

472-
def display_benchmark_metrics(self, *args, **kwargs) -> None:
472+
def display_benchmark_metrics(self, **kwargs) -> None:
473473
"""Display deployment configs benchmark metrics."""
474-
print(self.benchmark_metrics.to_markdown(index=False, floatfmt=".2f"), *args, **kwargs)
474+
df = self.benchmark_metrics
475+
476+
instance_type = kwargs.get("instance_type")
477+
if instance_type:
478+
df = df[df["Instance Type"].str.contains(instance_type)]
479+
480+
print(df.to_markdown(index=False, floatfmt=".2f"))
475481

476482
def list_deployment_configs(self) -> List[Dict[str, Any]]:
477483
"""List deployment configs for ``This`` model.
@@ -898,11 +904,12 @@ def _get_deployment_configs(
898904

899905
err = None
900906
for config_name, metadata_config in self._metadata_configs.items():
901-
resolved_config = metadata_config.resolved_config
902907
if selected_config_name == config_name:
903908
instance_type_to_use = selected_instance_type
904909
else:
905-
instance_type_to_use = resolved_config.get("default_inference_instance_type")
910+
instance_type_to_use = metadata_config.resolved_config.get(
911+
"default_inference_instance_type"
912+
)
906913

907914
if metadata_config.benchmark_metrics:
908915
err, metadata_config.benchmark_metrics = (
@@ -941,8 +948,7 @@ def _get_deployment_configs(
941948

942949
deployment_config_metadata = DeploymentConfigMetadata(
943950
config_name,
944-
metadata_config.benchmark_metrics,
945-
resolved_config,
951+
metadata_config,
946952
init_kwargs,
947953
deploy_kwargs,
948954
)

src/sagemaker/jumpstart/types.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,6 +1078,7 @@ class JumpStartMetadataConfig(JumpStartDataHolderType):
10781078
__slots__ = [
10791079
"base_fields",
10801080
"benchmark_metrics",
1081+
"acceleration_configs",
10811082
"config_components",
10821083
"resolved_metadata_config",
10831084
"config_name",
@@ -1115,6 +1116,7 @@ def __init__(
11151116
if config and config.get("benchmark_metrics")
11161117
else None
11171118
)
1119+
self.acceleration_configs = config.get("acceleration_configs")
11181120
self.resolved_metadata_config: Optional[Dict[str, Any]] = None
11191121
self.config_name: Optional[str] = config_name
11201122
self.default_inference_config: Optional[str] = config.get("default_inference_config")
@@ -2293,6 +2295,7 @@ class DeploymentArgs(BaseDeploymentConfigDataHolder):
22932295
__slots__ = [
22942296
"image_uri",
22952297
"model_data",
2298+
"model_package_arn",
22962299
"environment",
22972300
"instance_type",
22982301
"compute_resource_requirements",
@@ -2310,6 +2313,7 @@ def __init__(
23102313
if init_kwargs is not None:
23112314
self.image_uri = init_kwargs.image_uri
23122315
self.model_data = init_kwargs.model_data
2316+
self.model_package_arn = init_kwargs.model_package_arn
23132317
self.instance_type = init_kwargs.instance_type
23142318
self.environment = init_kwargs.env
23152319
if init_kwargs.resources is not None:
@@ -2341,14 +2345,14 @@ class DeploymentConfigMetadata(BaseDeploymentConfigDataHolder):
23412345
def __init__(
23422346
self,
23432347
config_name: Optional[str] = None,
2344-
benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]] = None,
2345-
resolved_config: Optional[Dict[str, Any]] = None,
2348+
metadata_config: Optional[JumpStartMetadataConfig] = None,
23462349
init_kwargs: Optional[JumpStartModelInitKwargs] = None,
23472350
deploy_kwargs: Optional[JumpStartModelDeployKwargs] = None,
23482351
):
23492352
"""Instantiates DeploymentConfigMetadata object."""
23502353
self.deployment_config_name = config_name
2351-
self.deployment_args = DeploymentArgs(init_kwargs, deploy_kwargs, resolved_config)
2352-
self.benchmark_metrics = benchmark_metrics
2353-
if resolved_config is not None:
2354-
self.acceleration_configs = resolved_config.get("acceleration_configs")
2354+
self.deployment_args = DeploymentArgs(
2355+
init_kwargs, deploy_kwargs, metadata_config.resolved_config
2356+
)
2357+
self.benchmark_metrics = metadata_config.benchmark_metrics
2358+
self.acceleration_configs = metadata_config.acceleration_configs

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1284,7 +1284,7 @@ def wrapped_f(*args, **kwargs):
12841284
break
12851285
elif isinstance(res, dict):
12861286
keys = list(res.keys())
1287-
if "Instance Rate" not in keys[-1]:
1287+
if len(keys) == 0 or "Instance Rate" not in keys[-1]:
12881288
f.cache_clear()
12891289
elif len(res[keys[1]]) > len(res[keys[-1]]):
12901290
del res[keys[-1]]

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,7 @@ def test_model_display_benchmark_metrics(
19321932
model = JumpStartModel(model_id=model_id)
19331933

19341934
model.display_benchmark_metrics()
1935+
model.display_benchmark_metrics(instance_type="g5.12xlarge")
19351936

19361937
@mock.patch("sagemaker.jumpstart.model.get_init_kwargs")
19371938
@mock.patch("sagemaker.jumpstart.utils.verify_model_region_and_return_specs")

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,8 +1372,7 @@ def test_deployment_config_metadata():
13721372

13731373
deployment_config_metadata = DeploymentConfigMetadata(
13741374
jumpstart_config.config_name,
1375-
jumpstart_config.benchmark_metrics,
1376-
jumpstart_config.resolved_config,
1375+
jumpstart_config,
13771376
JumpStartModelInitKwargs(
13781377
model_id=specs.model_id,
13791378
model_data=INIT_KWARGS.get("model_data"),

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,7 @@ def get_base_deployment_configs_metadata(
378378
configs.append(
379379
DeploymentConfigMetadata(
380380
config_name=config_name,
381-
benchmark_metrics=jumpstart_config.benchmark_metrics,
382-
resolved_config=jumpstart_config.resolved_config,
381+
metadata_config=jumpstart_config,
383382
init_kwargs=get_mock_init_kwargs(
384383
get_base_spec_with_prototype_configs().model_id, config_name
385384
),

0 commit comments

Comments
 (0)