Skip to content

Commit d5c3b09

Browse files
gwang111Raymond Liu
and
Raymond Liu
committed
feature: right_size() for inference recommender (aws#792)
* feature: right_sizing() for inf rec * update imports and documentation * updated mixin * one more unit test * leverage more fixtures * one more unit test * check args changes and tests * naming change * fix doc strings, infer job_type * add sphynx doc change * integration testing * add integration tests for default case with sklearn * should delete model package only at the very end of tests * update integ test for defalut job * add integ test for advanced job * cleanup integ tests * update role * add quiet integ test * remove trailing spaces * fix doc string * fix logging format * reformatting * update sphinx doc * refactoring our error handling * refactor error handling * remove unused params from doc string * refactor error handling and integ tests * fix integ tests * switch role * refactor timeout and cleanup * fix flake8 failure * refactor check logic * remove unnecessary logging * make jobName truly unique Co-authored-by: Gary Wang <[email protected]> Co-authored-by: Raymond Liu <[email protected]>
1 parent 25997eb commit d5c3b09

File tree

13 files changed

+1122
-6
lines changed

13 files changed

+1122
-6
lines changed

doc/api/inference/model.rst

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Model
55
:members:
66
:undoc-members:
77
:show-inheritance:
8+
:inherited-members:
89

910
.. autoclass:: sagemaker.model.FrameworkModel
1011
:members:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes for using Inference Recommender with Amazon SageMaker."""
14+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
18+
from typing import List, Dict, Optional
19+
20+
import sagemaker
21+
22+
from sagemaker.parameter import CategoricalParameter
23+
24+
INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING = {
25+
"xgboost": "XGBOOST",
26+
"sklearn": "SAGEMAKER-SCIKIT-LEARN",
27+
"pytorch": "PYTORCH",
28+
"tensorflow": "TENSORFLOW",
29+
"mxnet": "MXNET",
30+
}
31+
32+
LOGGER = logging.getLogger("sagemaker")
33+
34+
35+
class Phase:
    """Used to store phases of a traffic pattern to perform endpoint load testing.

    Required for an Advanced Inference Recommendations Job
    """

    def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int):
        """Initialize a `Phase`.

        Args:
            duration_in_seconds (int): How long this phase of the load test runs.
            initial_number_of_users (int): Number of concurrent users at the start of the phase.
            spawn_rate (int): Number of new users added per second during the phase.
        """
        # Stored pre-serialized in the shape expected by the CreateInferenceRecommendationsJob
        # TrafficPattern.Phases API field.
        self.to_json = {
            "DurationInSeconds": duration_in_seconds,
            "InitialNumberOfUsers": initial_number_of_users,
            "SpawnRate": spawn_rate,
        }
48+
49+
50+
class ModelLatencyThreshold:
    """Used to store inference request/response latency to perform endpoint load testing.

    Required for an Advanced Inference Recommendations Job
    """

    def __init__(self, percentile: str, value_in_milliseconds: int):
        """Initialize a `ModelLatencyThreshold`.

        Args:
            percentile (str): The latency percentile to constrain, e.g. ``"P95"``.
            value_in_milliseconds (int): Maximum allowed latency at that percentile, in ms.
        """
        # Stored pre-serialized in the shape expected by the CreateInferenceRecommendationsJob
        # StoppingConditions.ModelLatencyThresholds API field.
        self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds}
59+
60+
61+
class InferenceRecommenderMixin:
    """A mixin class for SageMaker ``Inference Recommender`` that will be extended by ``Model``"""

    def right_size(
        self,
        sample_payload_url: str = None,
        supported_content_types: List[str] = None,
        supported_instance_types: List[str] = None,
        job_name: str = None,
        framework: str = None,
        job_duration_in_seconds: int = None,
        hyperparameter_ranges: List[Dict[str, CategoricalParameter]] = None,
        phases: List[Phase] = None,
        traffic_type: str = None,
        max_invocations: int = None,
        model_latency_thresholds: List[ModelLatencyThreshold] = None,
        max_tests: int = None,
        max_parallel_tests: int = None,
        log_level: Optional[str] = "Verbose",
    ):
        """Recommends an instance type for a SageMaker or BYOC model.

        Args:
            sample_payload_url (str): The S3 path where the sample payload is stored.
            supported_content_types: (list[str]): The supported MIME types for the input data.
            supported_instance_types (list[str]): A list of the instance types that this model
                is expected to work on. (default: None).
            job_name (str): The name of the Inference Recommendations Job. (default: None).
            framework (str): The machine learning framework of the Image URI.
                Only required to specify if you bring your own custom containers (default: None).
            job_duration_in_seconds (int): The maximum job duration that a job can run for.
                (default: None).
            hyperparameter_ranges (list[Dict[str, sagemaker.parameter.CategoricalParameter]]):
                Specifies the hyper parameters to be used during endpoint load tests.
                `instance_type` must be specified as a hyperparameter range.
                `env_vars` can be specified as an optional hyperparameter range. (default: None).
                Example::

                    hyperparameter_ranges = [{
                        'instance_types': CategoricalParameter(['ml.c5.xlarge', 'ml.c5.2xlarge']),
                        'OMP_NUM_THREADS': CategoricalParameter(['1', '2', '3', '4'])
                    }]

            phases (list[Phase]): Specifies the criteria for increasing load
                during endpoint load tests. (default: None).
            traffic_type (str): Specifies the traffic type that matches the phases. (default: None).
            max_invocations (int): defines invocation limit for endpoint load tests (default: None).
            model_latency_thresholds (list[ModelLatencyThreshold]): defines the response latency
                thresholds for endpoint load tests (default: None).
            max_tests (int): restricts how many endpoints are allowed to be
                spun up for this job (default: None).
            max_parallel_tests (int): restricts how many concurrent endpoints
                this job is allowed to spin up (default: None).
            log_level (str): specifies the inline output when waiting for right_size to complete
                (default: "Verbose").

        Returns:
            sagemaker.model.Model: A SageMaker ``Model`` object. See
            :func:`~sagemaker.model.Model` for full details.
        """
        if not isinstance(self, sagemaker.model.ModelPackage):
            raise ValueError("right_size() is currently only supported with a registered model")

        if not framework and self._framework():
            # BUGFIX: look the mapping up with the framework *value*, not the bound method
            # object. `self._framework` (without the call) could never match a string key,
            # so the inferred framework silently stayed None.
            framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework(), framework)

        framework_version = self._get_framework_version()

        # Each converter returns None when its inputs were not supplied; any non-None
        # result means the caller asked for an Advanced job.
        endpoint_configurations = self._convert_to_endpoint_configurations_json(
            hyperparameter_ranges=hyperparameter_ranges
        )
        traffic_pattern = self._convert_to_traffic_pattern_json(
            traffic_type=traffic_type, phases=phases
        )
        stopping_conditions = self._convert_to_stopping_conditions_json(
            max_invocations=max_invocations, model_latency_thresholds=model_latency_thresholds
        )
        resource_limit = self._convert_to_resource_limit_json(
            max_tests=max_tests, max_parallel_tests=max_parallel_tests
        )

        if endpoint_configurations or traffic_pattern or stopping_conditions or resource_limit:
            LOGGER.info("Advance Job parameters were specified. Running Advanced job...")
            job_type = "Advanced"
        else:
            LOGGER.info("Advance Job parameters were not specified. Running Default job...")
            job_type = "Default"

        self._init_sagemaker_session_if_does_not_exist()

        ret_name = self.sagemaker_session.create_inference_recommendations_job(
            role=self.role,
            job_name=job_name,
            job_type=job_type,
            job_duration_in_seconds=job_duration_in_seconds,
            model_package_version_arn=self.model_package_arn,
            framework=framework,
            framework_version=framework_version,
            sample_payload_url=sample_payload_url,
            supported_content_types=supported_content_types,
            supported_instance_types=supported_instance_types,
            endpoint_configurations=endpoint_configurations,
            traffic_pattern=traffic_pattern,
            stopping_conditions=stopping_conditions,
            resource_limit=resource_limit,
        )

        # Block until the job finishes, then cache results so deploy() can pick up the
        # top recommendation via _check_inference_recommender_args().
        self.inference_recommender_job_results = (
            self.sagemaker_session.wait_for_inference_recommendations_job(
                ret_name, log_level=log_level
            )
        )
        self.inference_recommendations = self.inference_recommender_job_results.get(
            "InferenceRecommendations"
        )

        return self

    def _check_inference_recommender_args(
        self,
        instance_type=None,
        initial_instance_count=None,
        accelerator_type=None,
        serverless_inference_config=None,
        async_inference_config=None,
    ):
        """Validates that Inference Recommendation parameters can be used in `model.deploy()`

        Args:
            instance_type (str): The EC2 instance type to deploy this Model to.
                For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
                serverless inference or the model has not called ``right_size()``,
                then it is required to deploy a model.
                (default: None)
            initial_instance_count (int): The initial number of instances to run
                in the ``Endpoint`` created from this ``Model``. If not using
                serverless inference or the model has not called ``right_size()``,
                then it need to be a number larger or equals
                to 1 (default: None)
            accelerator_type (str): whether accelerator_type has been passed into `model.deploy()`.
            serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig)):
                whether serverless_inference_config has been passed into `model.deploy()`.
            async_inference_config (sagemaker.model_monitor.AsyncInferenceConfig):
                whether async_inference_config has been passed into `model.deploy()`.

        Returns:
            (string, int) or None: Top instance_type and associated initial_instance_count
            if self.inference_recommender_job_results has been generated. Otherwise, return None.
        """
        if accelerator_type:
            raise ValueError("accelerator_type is not compatible with right_size().")
        if instance_type or initial_instance_count:
            LOGGER.warning(
                "instance_type or initial_instance_count specified."
                "Overriding right_size() recommendations."
            )
            return None
        if async_inference_config:
            LOGGER.warning(
                "async_inference_config is specified. Overriding right_size() recommendations."
            )
            return None
        if serverless_inference_config:
            LOGGER.warning(
                "serverless_inference_config is specified. Overriding right_size() recommendations."
            )
            return None

        # No explicit overrides: use the top-ranked recommendation from the job results.
        instance_type = self.inference_recommendations[0]["EndpointConfiguration"]["InstanceType"]
        initial_instance_count = self.inference_recommendations[0]["EndpointConfiguration"][
            "InitialInstanceCount"
        ]
        return (instance_type, initial_instance_count)

    def _convert_to_endpoint_configurations_json(
        self, hyperparameter_ranges: List[Dict[str, CategoricalParameter]]
    ):
        """Bundle right_size() parameters into an endpoint configuration for Advanced job"""
        if not hyperparameter_ranges:
            return None

        endpoint_configurations_to_json = []
        for parameter_range in hyperparameter_ranges:
            if not parameter_range.get("instance_types"):
                raise ValueError("instance_type must be defined as a hyperparameter_range")
            # Copy before popping so the caller's dict is not mutated.
            parameter_range = parameter_range.copy()
            instance_types = parameter_range.get("instance_types").values
            parameter_range.pop("instance_types")

            # One endpoint configuration per instance type; the remaining entries become
            # categorical environment-variable ranges for that configuration.
            for instance_type in instance_types:
                parameter_ranges = []
                for name, param in parameter_range.items():
                    as_json = param.as_json_range(name)
                    # API expects the key "Value" (singular), unlike the tuner JSON.
                    as_json["Value"] = as_json.pop("Values")
                    parameter_ranges.append(as_json)
                endpoint_configurations_to_json.append(
                    {
                        "EnvironmentParameterRanges": {
                            "CategoricalParameterRanges": parameter_ranges
                        },
                        "InstanceType": instance_type,
                    }
                )

        return endpoint_configurations_to_json

    def _convert_to_traffic_pattern_json(self, traffic_type: str, phases: List[Phase]):
        """Bundle right_size() parameters into a traffic pattern for Advanced job"""
        if not phases:
            return None
        return {
            "Phases": [phase.to_json for phase in phases],
            "TrafficType": traffic_type if traffic_type else "PHASES",
        }

    def _convert_to_resource_limit_json(self, max_tests: int, max_parallel_tests: int):
        """Bundle right_size() parameters into a resource limit for Advanced job"""
        if not max_tests and not max_parallel_tests:
            return None
        return {
            "MaxNumberOfTests": max_tests,
            "MaxParallelOfTests": max_parallel_tests,
        }

    def _convert_to_stopping_conditions_json(
        self, max_invocations: int, model_latency_thresholds: List[ModelLatencyThreshold]
    ):
        """Bundle right_size() parameters into stopping conditions for Advanced job"""
        if not max_invocations and not model_latency_thresholds:
            return None
        return {
            "MaxInvocations": max_invocations,
            "ModelLatencyThresholds": [threshold.to_json for threshold in model_latency_thresholds],
        }

src/sagemaker/model.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from sagemaker.workflow import is_pipeline_variable
4949
from sagemaker.workflow.entities import PipelineVariable
5050
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
51+
from sagemaker.inference_recommender.inference_recommender_mixin import InferenceRecommenderMixin
5152

5253
LOGGER = logging.getLogger("sagemaker")
5354

@@ -83,7 +84,7 @@ def delete_model(self, *args, **kwargs) -> None:
8384
SAGEMAKER_OUTPUT_LOCATION = "sagemaker_s3_output"
8485

8586

86-
class Model(ModelBase):
87+
class Model(ModelBase, InferenceRecommenderMixin):
8788
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
8889

8990
def __init__(
@@ -279,6 +280,8 @@ def __init__(
279280
self._is_compiled_model = False
280281
self._compilation_job_name = None
281282
self._is_edge_packaged_model = False
283+
self.inference_recommender_job_results = None
284+
self.inference_recommendations = None
282285
self._enable_network_isolation = enable_network_isolation
283286
self.model_kms_key = model_kms_key
284287
self.image_config = image_config
@@ -1050,11 +1053,13 @@ def deploy(
10501053
Args:
10511054
initial_instance_count (int): The initial number of instances to run
10521055
in the ``Endpoint`` created from this ``Model``. If not using
1053-
serverless inference, then it need to be a number larger or equals
1056+
serverless inference or the model has not called ``right_size()``,
1057+
then it need to be a number larger or equals
10541058
to 1 (default: None)
10551059
instance_type (str): The EC2 instance type to deploy this Model to.
10561060
For example, 'ml.p2.xlarge', or 'local' for local mode. If not using
1057-
serverless inference, then it is required to deploy a model.
1061+
serverless inference or the model has not called ``right_size()``,
1062+
then it is required to deploy a model.
10581063
(default: None)
10591064
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
10601065
serializer object, used to encode data for an inference endpoint
@@ -1118,6 +1123,18 @@ def deploy(
11181123
is not None. Otherwise, return None.
11191124
"""
11201125
removed_kwargs("update_endpoint", kwargs)
1126+
1127+
if self.inference_recommender_job_results:
1128+
inference_recommendation = self._check_inference_recommender_args(
1129+
instance_type,
1130+
initial_instance_count,
1131+
accelerator_type,
1132+
serverless_inference_config,
1133+
async_inference_config,
1134+
)
1135+
if inference_recommendation:
1136+
instance_type, initial_instance_count = inference_recommendation
1137+
11211138
self._init_sagemaker_session_if_does_not_exist(instance_type)
11221139

11231140
tags = add_jumpstart_tags(

src/sagemaker/session.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import time
2222
import typing
2323
import warnings
24+
import uuid
2425
from typing import List, Dict, Any, Sequence, Optional
2526

2627
import boto3
@@ -4810,7 +4811,8 @@ def create_inference_recommendations_job(
48104811
"""
48114812

48124813
if not job_name:
4813-
job_name = "SMPYTHONSDK-" + str(round(time.time()))
4814+
unique_tail = uuid.uuid4()
4815+
job_name = "SMPYTHONSDK-" + str(unique_tail)
48144816
job_description = "#python-sdk-create"
48154817

48164818
create_inference_recommendations_job_request = (

0 commit comments

Comments
 (0)