
Commit 89fa436

cj-zhang (Joseph Zhang) authored and committed
feature: add benchmarking classes for llm benchmarking project (aws#1435)
* Add inference benchmarking classes, types, and some function definitions.
* Presets and tests, class definitions.
* Add Benchmarker tests and fix formatting.
* Update classes and presets, add sample payload uploader.
* Change BenchmarkJob class to use @dataclass.
* Fix ValueError message.
* Use more list comprehension and pythonic practices.
* Add list to BenchmarkJob. Minor fixes to validations.
* Update preset values/names, docstrings, use Model.create().
* Format and remove local testing change.
* Fix outdated IR functions in Session. Implement more class functions and fix bugs caught during e2e testing using XGB model.
* Remove test notebook.
* Format dataframes, make list functions return df, iron out bugs in sample payload uploader.
* Add docstrings.
* Set default instance count for Benchmark.deploy() to 1.
* Revert traffic pattern durations to 120s (current min supported by IR).
* Wire new LLM benchmarking project fields to create_benchmark_job flow. Replace prints with logging.
* Fixes for bugs caught during e2e testing. Add processing for detailed metrics DF and fix to_model() bugs.
* Change list_benchmark_jobs() to return a list.
* Add metric data CSV processing into dict. Minor style/bugfixes.
* Implement compare_benchmarks(). Improve dataframe formatting and docstrings.
* Fix formatting.

---------

Co-authored-by: Joseph Zhang <[email protected]>
1 parent 2b35717 commit 89fa436

21 files changed: +2000 −37 lines changed
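The commit message describes an end-to-end flow: create a benchmark job, list jobs, and compare results. A minimal sketch of that flow follows. Benchmarker, create_benchmark_job, list_benchmark_jobs(), and compare_benchmarks() are named in the message, but the import path and every signature shown here are assumptions; neither appears in the two files excerpted below.

# Hypothetical sketch of the flow described in the commit message above.
# The import path and all arguments are assumptions, not the committed API.
from sagemaker.benchmarking.benchmarker import Benchmarker  # assumed module path

benchmarker = Benchmarker()  # assumed entry point; constructor args unknown
job = benchmarker.create_benchmark_job(
    model_name="my-model",  # hypothetical model name; other required args omitted
)
print(benchmarker.list_benchmark_jobs())  # returns a list, per the commit message
benchmarker.compare_benchmarks()  # implemented in this commit; arguments assumed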
New file, +33 lines:
@@ -0,0 +1,33 @@
"""Placeholder docstring"""

from __future__ import absolute_import
import logging
import pandas as pd
from sagemaker.benchmarking.config_grid import ConfigGrid
from sagemaker.benchmarking.preset.benchmarks.methodology import Methodology
from sagemaker.benchmarking.preset.stopping_conditions import PresetStoppingConditions
from sagemaker.benchmarking.traffic_pattern_config import TrafficPatternConfig
from sagemaker.benchmarking.preset.traffic_pattern import PresetTrafficPatternConfig
from sagemaker.benchmarking.payload.sample_payload_uploader import SamplePayloadUploader

__all__ = (
    "ConfigGrid",
    "Methodology",
    "PresetStoppingConditions",
    "TrafficPatternConfig",
    "PresetTrafficPatternConfig",
    "SamplePayloadUploader",
)


pd.set_option("display.max_rows", 500)
pd.set_option("display.max_columns", 500)
pd.set_option("display.width", 150)
pd.set_option("display.max_colwidth", None)

logger = logging.getLogger(__name__)
logger.propagate = False
logger.setLevel(logging.DEBUG)
streamHandler = logging.StreamHandler()
streamHandler.setFormatter(logging.Formatter("[Benchmarker][%(levelname)s] %(message)s"))
logger.addHandler(streamHandler)
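Since this package __init__ re-exports the public classes in __all__ and applies pandas display settings at import time, consumers can pull everything from the package root. A minimal import sketch, grounded in the file above with nothing beyond it assumed:

# Import the public surface declared in __all__ above.
from sagemaker.benchmarking import (
    ConfigGrid,
    Methodology,
    PresetStoppingConditions,
    TrafficPatternConfig,
    PresetTrafficPatternConfig,
    SamplePayloadUploader,
)

import pandas as pd

# Importing the package has already applied the display settings above,
# so wide benchmark DataFrames print without column truncation.
assert pd.get_option("display.max_colwidth") is None

One side effect worth noting: the module sets global pandas display options and attaches a non-propagating DEBUG handler to its logger at import time, which affects the whole process, not just this package.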
New file, +164 lines:
@@ -0,0 +1,164 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Holds the InferenceBenchmark class."""
from __future__ import absolute_import

from datetime import datetime
from dataclasses import dataclass, field
import logging
import csv
from typing import List, Dict, Optional, Union
import pandas as pd
from sagemaker.predictor import Predictor

from sagemaker import Model, Session
from sagemaker.benchmarking.constants import COLUMN_PREFIXES_TO_TRIM
from sagemaker.benchmarking.utils import get_benchmarks_output_csv

logger = LOGGER = logging.getLogger(__name__)


@dataclass
class InferenceBenchmark:
    """Class definition for an individual benchmark."""

    benchmark_id: Optional[str] = field(
        default=None,
        metadata={"help": "The benchmark ID which uniquely identifies each benchmark"},
    )
    endpoint_config: Optional[Dict[str, Union[str, int, Dict[str, int]]]] = field(
        default=None, metadata={"help": "Defines the endpoint configuration parameters"}
    )
    metrics: Optional[Dict[str, Union[int, float]]] = field(
        default=None, metadata={"help": "Metrics emitted during the benchmark"}
    )
    model_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the model used by this job for benchmarking"},
    )
    model_configuration: Optional[Dict[str, Union[str, List[Dict[str, str]]]]] = field(
        default=None, metadata={"help": "Defines the model configuration"}
    )
    invocation_start_time: Optional[datetime] = field(
        default=None,
        metadata={"help": "A timestamp that shows when the benchmark started"},
    )
    invocation_end_time: Optional[datetime] = field(
        default=None,
        metadata={"help": "A timestamp that shows when the benchmark completed"},
    )
    role_arn: Optional[str] = field(
        default=None, metadata={"help": "Define the role for the endpoint"}
    )
    sagemaker_session: Optional[Session] = field(
        default=Session(), metadata={"help": "Define sagemaker session for execution"}
    )

    def detailed_metrics_df(self) -> pd.DataFrame:
        """Returns a dataframe with metrics at every concurrency level."""
        benchmark_results_output_location = get_benchmarks_output_csv(
            job_name=self.benchmark_id.split("/")[0], session=self.sagemaker_session
        )
        df = pd.read_csv(benchmark_results_output_location)
        # Get only rows for this benchmark & capitalize column headers for consistency.
        df = df.loc[df["RecommendationId"] == self.benchmark_id]
        df.columns = df.columns.str.replace("RecommendationId", "BenchmarkId")
        df.rename(columns=lambda x: x[0].upper() + x[1:], inplace=True)
        return df.sort_values(by="Concurrency")

    def detailed_metrics_dict(self) -> List[Dict[str, str]]:
        """Returns a dictionary with metrics at every concurrency level."""
        benchmark_results_output_location = get_benchmarks_output_csv(
            job_name=self.benchmark_id.split("/")[0], session=self.sagemaker_session
        )
        with open(benchmark_results_output_location, "r") as file:
            csv_reader = csv.DictReader(file)
            return list(csv_reader)

    def key_metrics_df(self) -> pd.DataFrame:
        """Returns a dataframe with metrics at the max concurrency level."""
        key_metrics = pd.json_normalize(self.key_metrics_dict())
        df = pd.DataFrame.from_dict(key_metrics)
        df.columns = df.columns.str.replace(COLUMN_PREFIXES_TO_TRIM, "", regex=True)
        return df

    def key_metrics_dict(self) -> Dict:
        """Returns a dictionary with metrics at the max concurrency level."""
        return {
            "BenchmarkId": self.benchmark_id,
            "ModelName": self.model_name,
            "EndpointConfig": self.endpoint_config,
            "Metrics": self.metrics,
            "ModelConfiguration": self.model_configuration,
            "InvocationStartTime": self.invocation_start_time,
            "InvocationEndTime": self.invocation_end_time,
        }

    def to_model(
        self,
        role_arn: Optional[str] = None,
        predictor_cls: Optional[Predictor] = Predictor,
    ) -> Model:
        """Creates a Model from this benchmark.

        Args:
            role_arn (str): The role for the endpoint.
            predictor_cls (Predictor): A function to call to create a predictor.
        """
        try:
            response = self.sagemaker_session.describe_model(name=self.model_name)
            if "PrimaryContainer" in response.keys():
                container = response.get("PrimaryContainer")
            elif "Containers" in response.keys():
                if len(response.get("Containers")) > 1:
                    logger.warning(
                        "More than one container found for the model, using the first one."
                    )
                container = response.get("Containers")[0]
            else:
                raise ValueError("No containers defined for model {}".format(self.model_name))

            return Model(
                role=role_arn or self.role_arn,
                image_uri=container.get("Image"),
                model_data=self._get_model_data_location(container=container),
                env=self._convert_model_configuration_to_env(),
                predictor_cls=predictor_cls,
                sagemaker_session=self.sagemaker_session,
            )
        except Exception as e:
            raise Exception("Failed to describe model with name {}".format(self.model_name)) from e

    def _get_model_data_location(self, container: Dict) -> str:
        """Returns S3 location of a container's model data.

        Args:
            container (Dict): Container configuration from DescribeEndpoint() call.
        """
        # ModelDataUrl will only point to compressed tar archives.
        # ModelDataSource can also include compressed artifacts, and defines compression type.
        if container.get("ModelDataUrl"):
            return container.get("ModelDataUrl")

        return container.get("ModelDataSource")

    def _convert_model_configuration_to_env(self) -> Dict[str, str]:
        """Converts model configuration to env params."""
        if self.model_configuration is None:
            return {}

        env_params = {
            e.get("Key"): e.get("Value")
            for e in self.model_configuration.get("EnvironmentParameters", [])
        }
        return env_params
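A hedged usage sketch of InferenceBenchmark as defined above. The module path, the benchmark_id format, and the instance type are assumptions, and all literal values are placeholders:

# Hypothetical usage of the InferenceBenchmark dataclass above.
# The module path below is assumed; it is not shown on this page.
from sagemaker.benchmarking.inference_benchmark import InferenceBenchmark

benchmark = InferenceBenchmark(
    benchmark_id="my-benchmark-job/variant-1",  # "<job_name>/..." format implied by split("/")[0]
    model_name="my-xgb-model",  # placeholder
    role_arn="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder
)

# Per-concurrency metrics from the job's output CSV, sorted by Concurrency.
detailed = benchmark.detailed_metrics_df()

# Rebuild a deployable Model from the benchmarked model's container and env.
model = benchmark.to_model()
predictor = model.deploy(initial_instance_count=1, instance_type="ml.g5.2xlarge")  # instance type assumed

One design note: sagemaker_session defaults to a single Session() instance created when the class is defined, so that default is shared across all InferenceBenchmark objects; callers who need isolation should pass their own sagemaker_session.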
