Benchmark feature fixes tests #3

Merged
merged 31 commits on May 2, 2024
44 changes: 17 additions & 27 deletions src/sagemaker/jumpstart/model.py
@@ -14,7 +14,6 @@

from __future__ import absolute_import

from functools import lru_cache
from typing import Dict, List, Optional, Any, Union
import pandas as pd
from botocore.exceptions import ClientError
@@ -49,6 +48,7 @@
get_metrics_from_deployment_configs,
add_instance_rate_stats_to_benchmark_metrics,
deployment_config_response_data,
_deployment_config_lru_cache,
)
from sagemaker.jumpstart.constants import JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -464,17 +464,14 @@ def benchmark_metrics(self) -> pd.DataFrame:
Returns:
Benchmark Metrics: Pandas DataFrame object.
"""
benchmark_metrics_data = self._get_deployment_configs_benchmarks_data(
self.config_name, self.instance_type
)
keys = list(benchmark_metrics_data.keys())
# Sort by Config Name and Instance Type column values
df = pd.DataFrame(benchmark_metrics_data).sort_values(by=[keys[1], keys[0]])
return df
df = pd.DataFrame(self._get_deployment_configs_benchmarks_data())
default_mask = df.apply(lambda row: any("Default" in str(val) for val in row), axis=1)
sorted_df = pd.concat([df[default_mask], df[~default_mask]])
return sorted_df
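
The new implementation above surfaces the default configuration first rather than sorting by column values. A minimal sketch of that masking-and-concat ordering, using made-up rows (the column names mirror the ones built in `get_metrics_from_deployment_configs` below):

```python
import pandas as pd

# Hypothetical benchmark rows; only the shape matters for the illustration.
df = pd.DataFrame({
    "Instance Type": ["ml.g5.2xlarge", "ml.g5.12xlarge (Default)", "ml.p4d.24xlarge"],
    "Config Name": ["lmi-optimized", "lmi", "tgi"],
    "Latency (sec)": ["0.9", "0.4", "0.3"],
})

# Rows containing "(Default)" in any cell float to the top; everything else
# keeps its original relative order.
default_mask = df.apply(lambda row: any("Default" in str(val) for val in row), axis=1)
sorted_df = pd.concat([df[default_mask], df[~default_mask]])
print(sorted_df)
```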

def display_benchmark_metrics(self) -> None:
def display_benchmark_metrics(self, *args, **kwargs) -> None:
"""Display deployment configs benchmark metrics."""
print(self.benchmark_metrics.to_markdown(index=False))
print(self.benchmark_metrics.to_markdown(index=False), *args, **kwargs)

def list_deployment_configs(self) -> List[Dict[str, Any]]:
"""List deployment configs for ``This`` model.
@@ -874,36 +871,29 @@ def register_deploy_wrapper(*args, **kwargs):

return model_package

@lru_cache
def _get_deployment_configs_benchmarks_data(
self, config_name: str, instance_type: str
) -> Dict[str, Any]:
@_deployment_config_lru_cache
def _get_deployment_configs_benchmarks_data(self) -> Dict[str, Any]:
"""Deployment configs benchmark metrics.

Args:
config_name (str): Name of selected deployment config.
instance_type (str): The selected Instance type.
Returns:
Dict[str, List[str]]: Deployment config benchmark data.
"""
return get_metrics_from_deployment_configs(
config_name,
instance_type,
self._get_deployment_configs(config_name, instance_type),
self._get_deployment_configs(None, None),
)

@lru_cache
@_deployment_config_lru_cache
def _get_deployment_configs(
self, selected_config_name: str, selected_instance_type: str
self, selected_config_name: Optional[str], selected_instance_type: Optional[str]
) -> List[DeploymentConfigMetadata]:
"""Retrieve deployment configs metadata.

Args:
selected_config_name (str): The name of the selected deployment config.
selected_instance_type (str): The selected instance type.
selected_config_name (Optional[str]): The name of the selected deployment config.
selected_instance_type (Optional[str]): The selected instance type.
"""
deployment_configs = []
if self._metadata_configs is None:
if not self._metadata_configs:
return deployment_configs

err = None
@@ -940,9 +930,9 @@ def _get_deployment_configs(
)
deployment_configs.append(deployment_config_metadata)

if err is not None and "is not authorized to perform: pricing:GetProducts" in err:
if err and err["Code"] == "AccessDeniedException":
error_message = "Instance rate metrics will be omitted. Reason: %s"
JUMPSTART_LOGGER.warning(error_message, err)
JUMPSTART_LOGGER.warning(error_message, err["Message"])

return deployment_configs
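
With the utils change in this PR, `err` is now the full `Error` dict from botocore rather than a bare message string, which is what enables the `err["Code"] == "AccessDeniedException"` branch above. A small self-contained sketch of that flow (the pricing call is stubbed and all names here are illustrative only):

```python
from botocore.exceptions import ClientError

def fake_get_instance_rate_per_hour(instance_type: str, region: str) -> dict:
    # Stand-in for the pricing lookup when the caller lacks pricing:GetProducts.
    raise ClientError(
        {"Error": {"Code": "AccessDeniedException",
                   "Message": "is not authorized to perform: pricing:GetProducts"}},
        "GetProducts",
    )

err = None
try:
    rate = fake_get_instance_rate_per_hour("ml.g5.2xlarge", "us-west-2")
except ClientError as e:
    # Keep the whole Error dict (Code + Message), as the utils change does.
    err = e.response["Error"]

if err and err["Code"] == "AccessDeniedException":
    print("Instance rate metrics will be omitted. Reason: %s" % err["Message"])
```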

118 changes: 86 additions & 32 deletions src/sagemaker/jumpstart/utils.py
@@ -12,8 +12,10 @@
# language governing permissions and limitations under the License.
"""This module contains utilities related to SageMaker JumpStart."""
from __future__ import absolute_import

import logging
import os
from functools import lru_cache, wraps
from typing import Any, Dict, List, Set, Optional, Tuple, Union
from urllib.parse import urlparse
import boto3
@@ -1023,7 +1025,9 @@ def get_jumpstart_configs(
raise ValueError(f"Unknown script scope: {scope}.")

if not config_names:
config_names = metadata_configs.configs.keys() if metadata_configs else []
config_names = (
metadata_configs.config_rankings.get("overall").rankings if metadata_configs else []
)

return (
{config_name: metadata_configs.configs[config_name] for config_name in config_names}
@@ -1035,39 +1039,38 @@
def add_instance_rate_stats_to_benchmark_metrics(
region: str,
benchmark_metrics: Optional[Dict[str, List[JumpStartBenchmarkStat]]],
) -> Optional[Tuple[str, Dict[str, List[JumpStartBenchmarkStat]]]]:
) -> Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]:
"""Adds instance types metric stats to the given benchmark_metrics dict.

Args:
region (str): AWS region.
benchmark_metrics (Optional[Dict[str, List[JumpStartBenchmarkStat]]]):
Returns:
Tuple[str, Dict[str, List[JumpStartBenchmarkStat]]]:
Contains Error message and metrics dict.
Optional[Tuple[Dict[str, str], Dict[str, List[JumpStartBenchmarkStat]]]]:
Contains Error and metrics.
"""
if benchmark_metrics is None:
if not benchmark_metrics:
return None

err_message = None
final_benchmark_metrics = {}
for instance_type, benchmark_metric_stats in benchmark_metrics.items():
instance_type = instance_type if instance_type.startswith("ml.") else f"ml.{instance_type}"

if not has_instance_rate_stat(benchmark_metric_stats) and err_message is None:
if not has_instance_rate_stat(benchmark_metric_stats) and not err_message:
try:
instance_type_rate = get_instance_rate_per_hour(
instance_type=instance_type, region=region
)

if benchmark_metric_stats:
benchmark_metric_stats.append(JumpStartBenchmarkStat(instance_type_rate))
else:
benchmark_metric_stats = [JumpStartBenchmarkStat(instance_type_rate)]
if not benchmark_metric_stats:
benchmark_metric_stats = []
benchmark_metric_stats.append(JumpStartBenchmarkStat(instance_type_rate))

final_benchmark_metrics[instance_type] = benchmark_metric_stats
except ClientError as e:
final_benchmark_metrics[instance_type] = benchmark_metric_stats
err_message = e.response["Error"]["Message"]
err_message = e.response["Error"]
except Exception: # pylint: disable=W0703
final_benchmark_metrics[instance_type] = benchmark_metric_stats
else:
@@ -1086,33 +1089,32 @@ def has_instance_rate_stat(benchmark_metric_stats: Optional[List[JumpStartBenchm
bool: Whether the benchmark metric stats contains instance rate metric stat.
"""
if benchmark_metric_stats is None:
return False

return True
for benchmark_metric_stat in benchmark_metric_stats:
if benchmark_metric_stat.name.lower() == "instance rate":
return True

return False
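
Rewritten this way, the helper reports whether an `Instance Rate` entry is actually present instead of treating any non-`None` list as having one. A hedged usage sketch, assuming the import locations below and the `JumpStartBenchmarkStat` constructor shown in the tests further down:

```python
from sagemaker.jumpstart.types import JumpStartBenchmarkStat  # assumed location
from sagemaker.jumpstart.utils import has_instance_rate_stat

with_rate = [JumpStartBenchmarkStat({"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76"})]
without_rate = [JumpStartBenchmarkStat({"name": "Latency", "unit": "sec", "value": "0.4"})]

assert has_instance_rate_stat(with_rate) is True      # rate stat found
assert has_instance_rate_stat(without_rate) is False  # triggers a pricing lookup upstream
```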


def get_metrics_from_deployment_configs(
default_config_name: str,
default_instance_type: str,
deployment_configs: List[DeploymentConfigMetadata],
deployment_configs: Optional[List[DeploymentConfigMetadata]],
) -> Dict[str, List[str]]:
"""Extracts benchmark metrics from deployment configs metadata.

Args:
default_config_name (str): The name of the default deployment config.
default_instance_type (str): The name of the default instance type.
deployment_configs (List[DeploymentConfigMetadata]): List of deployment configs metadata.
deployment_configs (Optional[List[DeploymentConfigMetadata]]):
List of deployment configs metadata.
Returns:
Dict[str, List[str]]: Deployment configs bench metrics dict.
"""
if not deployment_configs:
return {}

data = {"Instance Type": [], "Config Name": []}
instance_rate_data = {}

for deployment_config in deployment_configs:
for index, deployment_config in enumerate(deployment_configs):
benchmark_metrics = deployment_config.benchmark_metrics
if deployment_config.deployment_args is None or benchmark_metrics is None:
if not deployment_config.deployment_args or not benchmark_metrics:
continue

for inner_index, current_instance_type in enumerate(benchmark_metrics):
@@ -1121,8 +1123,8 @@ def get_metrics_from_deployment_configs(
data["Config Name"].append(deployment_config.deployment_config_name)
instance_type_to_display = (
f"{current_instance_type} (Default)"
if current_instance_type == default_instance_type
and default_config_name == deployment_config.deployment_config_name
if index == 0
and current_instance_type == deployment_config.deployment_args.default_instance_type
else current_instance_type
)
data["Instance Type"].append(instance_type_to_display)
@@ -1131,10 +1133,9 @@
column_name = f"{metric.name} ({metric.unit})"

if metric.name.lower() == "instance rate":
if column_name in instance_rate_data:
instance_rate_data[column_name].append(metric.value)
else:
instance_rate_data[column_name] = [metric.value]
if column_name not in instance_rate_data:
instance_rate_data[column_name] = []
instance_rate_data[column_name].append(metric.value)
else:
if column_name not in data:
data[column_name] = []
@@ -1158,19 +1159,72 @@ def deployment_config_response_data(
List[Dict[str, Any]]: List of deployment config api response data.
"""
configs = []
if deployment_configs is None:
if not deployment_configs:
return configs

for deployment_config in deployment_configs:
deployment_config_json = deployment_config.to_json()
benchmark_metrics = deployment_config_json.get("BenchmarkMetrics")
if benchmark_metrics:
if benchmark_metrics and deployment_config.deployment_args:
deployment_config_json["BenchmarkMetrics"] = {
deployment_config.deployment_args.instance_type: benchmark_metrics.get(
deployment_config.deployment_args.instance_type
)
}

configs.append(deployment_config_json)

return configs
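
The response shaping above now also tolerates a missing `deployment_args` and narrows `BenchmarkMetrics` to the config's own instance type. A minimal sketch of that narrowing with hypothetical payload data:

```python
# Hypothetical JSON payload; the real one comes from DeploymentConfigMetadata.to_json().
deployment_config_json = {
    "DeploymentConfigName": "lmi",
    "BenchmarkMetrics": {
        "ml.g5.2xlarge": [{"name": "Latency", "value": "0.9"}],
        "ml.g5.12xlarge": [{"name": "Latency", "value": "0.4"}],
    },
}
selected_instance_type = "ml.g5.2xlarge"

benchmark_metrics = deployment_config_json.get("BenchmarkMetrics")
if benchmark_metrics:
    # Keep only the entry for the config's own instance type, mirroring the
    # narrowing done in deployment_config_response_data above.
    deployment_config_json["BenchmarkMetrics"] = {
        selected_instance_type: benchmark_metrics.get(selected_instance_type)
    }
```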


def _deployment_config_lru_cache(_func=None, *, maxsize: int = 128, typed: bool = False):
"""LRU cache for deployment configs."""

def has_instance_rate_metric(config: DeploymentConfigMetadata) -> bool:
"""Determines whether metadata config contains instance rate metric stat.

Args:
config (DeploymentConfigMetadata): Metadata config metadata.
Returns:
bool: Whether the metadata config contains instance rate metric stat.
"""
if config.benchmark_metrics is None:
return True
for benchmark_metric_stats in config.benchmark_metrics.values():
if not has_instance_rate_stat(benchmark_metric_stats):
return False
return True

def wrapper_cache(f):
f = lru_cache(maxsize=maxsize, typed=typed)(f)

@wraps(f)
def wrapped_f(*args, **kwargs):
res = f(*args, **kwargs)

# Clear the cache on the first call if the output does not contain
# instance rate metrics, as this is caused by a missing pricing policy.
if f.cache_info().hits == 0 and f.cache_info().misses == 1:
if isinstance(res, list):
for item in res:
if isinstance(
item, DeploymentConfigMetadata
) and not has_instance_rate_metric(item):
f.cache_clear()
break
elif isinstance(res, dict):
keys = list(res.keys())
if "Instance Rate" not in keys[-1]:
f.cache_clear()
elif len(res[keys[1]]) > len(res[keys[-1]]):
del res[keys[-1]]
f.cache_clear()
return res

wrapped_f.cache_info = f.cache_info
wrapped_f.cache_clear = f.cache_clear
return wrapped_f

if _func is None:
return wrapper_cache
return wrapper_cache(_func)
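
`_deployment_config_lru_cache` memoizes these lookups but invalidates the cache when the very first result comes back without instance-rate data (typically because the pricing permission is missing), so a later call can retry. A stripped-down sketch of the same pattern, independent of the JumpStart types and with a hypothetical completeness check:

```python
from functools import lru_cache, wraps

def retry_aware_lru_cache(maxsize: int = 128, typed: bool = False):
    """Memoize a function, but clear the cache when the result looks incomplete.

    A simplified illustration of the conditional-clearing idea, not the
    JumpStart implementation; is_complete decides whether a result is cacheable.
    """

    def is_complete(result) -> bool:
        # Hypothetical check: a dict result must carry an instance rate entry.
        return not (isinstance(result, dict) and "Instance Rate" not in result)

    def wrapper_cache(f):
        cached = lru_cache(maxsize=maxsize, typed=typed)(f)

        @wraps(f)
        def wrapped(*args, **kwargs):
            result = cached(*args, **kwargs)
            # Only the very first (miss-producing) call can invalidate the cache.
            info = cached.cache_info()
            if info.hits == 0 and info.misses == 1 and not is_complete(result):
                cached.cache_clear()
            return result

        wrapped.cache_info = cached.cache_info
        wrapped.cache_clear = cached.cache_clear
        return wrapped

    return wrapper_cache


@retry_aware_lru_cache()
def fetch_metrics(instance_type: str) -> dict:
    # Placeholder for a lookup that may lack pricing data on the first attempt.
    return {"Latency": "0.4"}  # no "Instance Rate" -> result is not kept cached
```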
13 changes: 10 additions & 3 deletions tests/unit/sagemaker/jumpstart/test_utils.py
@@ -1815,7 +1815,13 @@ def test_add_instance_rate_stats_to_benchmark_metrics_client_ex(
mock_get_instance_rate_per_hour,
):
mock_get_instance_rate_per_hour.side_effect = ClientError(
{"Error": {"Message": "is not authorized to perform: pricing:GetProducts"}}, "GetProducts"
{
"Error": {
"Message": "is not authorized to perform: pricing:GetProducts",
"Code": "AccessDenied",
},
},
"GetProducts",
)

err, out = utils.add_instance_rate_stats_to_benchmark_metrics(
@@ -1827,15 +1833,16 @@
},
)

assert err == "is not authorized to perform: pricing:GetProducts"
assert err["Message"] == "is not authorized to perform: pricing:GetProducts"
assert err["Code"] == "AccessDenied"
for key in out:
assert len(out[key]) == 1


@pytest.mark.parametrize(
"stats, expected",
[
(None, False),
(None, True),
(
[JumpStartBenchmarkStat({"name": "Instance Rate", "unit": "USD/Hrs", "value": "3.76"})],
True,
3 changes: 2 additions & 1 deletion tests/unit/sagemaker/jumpstart/utils.py
@@ -358,7 +358,8 @@ def get_base_deployment_configs_metadata(
else get_base_spec_with_prototype_configs()
)
configs = []
for config_name, jumpstart_config in specs.inference_configs.configs.items():
for config_name in specs.inference_configs.config_rankings.get("overall").rankings:
jumpstart_config = specs.inference_configs.configs.get(config_name)
benchmark_metrics = jumpstart_config.benchmark_metrics

if benchmark_metrics:
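
This fixture change walks configs in the order given by the `overall` ranking instead of dict order, so the first entry in the resulting list is the top-ranked config, which the `(Default)` labelling above relies on. A minimal sketch with made-up names (real spec objects are richer than these dicts):

```python
# Hypothetical, simplified stand-ins for the spec objects used in the fixture.
config_rankings = {"overall": ["lmi-optimized", "lmi", "tgi"]}  # ranked best-first
configs = {"tgi": "cfg-c", "lmi": "cfg-b", "lmi-optimized": "cfg-a"}

# Iterating the ranking keeps priority order; configs.keys() would only give
# the dict's insertion order.
ordered = [configs[name] for name in config_rankings["overall"]]
print(ordered)  # ['cfg-a', 'cfg-b', 'cfg-c']
```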