Skip to content

Commit a275880

Browse files
SSRraymondRaymond Liu
authored andcommitted
feature: Decouple model.right_size() from model registry (aws#3688)
Co-authored-by: Raymond Liu <[email protected]>
1 parent 084af58 commit a275880

File tree

5 files changed

+502
-23
lines changed

5 files changed

+502
-23
lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

+33-5
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class Phase:
3838
"""
3939

4040
def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int):
41-
"""Initialze a `Phase`"""
41+
"""Initialize a `Phase`"""
4242
self.to_json = {
4343
"DurationInSeconds": duration_in_seconds,
4444
"InitialNumberOfUsers": initial_number_of_users,
@@ -53,7 +53,7 @@ class ModelLatencyThreshold:
5353
"""
5454

5555
def __init__(self, percentile: str, value_in_milliseconds: int):
56-
"""Initialze a `ModelLatencyThreshold`"""
56+
"""Initialize a `ModelLatencyThreshold`"""
5757
self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds}
5858

5959

@@ -79,6 +79,12 @@ def right_size(
7979
):
8080
"""Recommends an instance type for a SageMaker or BYOC model.
8181
82+
Create a SageMaker ``Model`` or use a registered ``ModelPackage``,
83+
to start an Inference Recommender job.
84+
85+
The name of the created model is accessible in the ``name`` field of
86+
this ``Model`` after right_size returns.
87+
8288
Args:
8389
sample_payload_url (str): The S3 path where the sample payload is stored.
8490
supported_content_types: (list[str]): The supported MIME types for the input data.
@@ -119,8 +125,6 @@ def right_size(
119125
sagemaker.model.Model: A SageMaker ``Model`` object. See
120126
:func:`~sagemaker.model.Model` for full details.
121127
"""
122-
if not isinstance(self, sagemaker.model.ModelPackage):
123-
raise ValueError("right_size() is currently only supported with a registered model")
124128

125129
if not framework and self._framework():
126130
framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework(), framework)
@@ -149,12 +153,36 @@ def right_size(
149153

150154
self._init_sagemaker_session_if_does_not_exist()
151155

156+
if isinstance(self, sagemaker.model.Model) and not isinstance(
157+
self, sagemaker.model.ModelPackage
158+
):
159+
primary_container_def = self.prepare_container_def()
160+
if not self.name:
161+
self._ensure_base_name_if_needed(
162+
image_uri=primary_container_def["Image"],
163+
script_uri=self.source_dir,
164+
model_uri=self.model_data,
165+
)
166+
self._set_model_name_if_needed()
167+
168+
create_model_args = dict(
169+
name=self.name,
170+
role=self.role,
171+
container_defs=None,
172+
primary_container=primary_container_def,
173+
vpc_config=self.vpc_config,
174+
enable_network_isolation=self.enable_network_isolation(),
175+
)
176+
LOGGER.warning("Attempting to create new model with name %s", self.name)
177+
self.sagemaker_session.create_model(**create_model_args)
178+
152179
ret_name = self.sagemaker_session.create_inference_recommendations_job(
153180
role=self.role,
154181
job_name=job_name,
155182
job_type=job_type,
156183
job_duration_in_seconds=job_duration_in_seconds,
157-
model_package_version_arn=self.model_package_arn,
184+
model_name=self.name,
185+
model_package_version_arn=getattr(self, "model_package_arn", None),
158186
framework=framework,
159187
framework_version=framework_version,
160188
sample_payload_url=sample_payload_url,

src/sagemaker/session.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -4820,6 +4820,7 @@ def _create_inference_recommendations_job_request(
48204820
framework: str,
48214821
sample_payload_url: str,
48224822
supported_content_types: List[str],
4823+
model_name: str = None,
48234824
model_package_version_arn: str = None,
48244825
job_duration_in_seconds: int = None,
48254826
job_type: str = "Default",
@@ -4843,6 +4844,7 @@ def _create_inference_recommendations_job_request(
48434844
framework (str): The machine learning framework of the Image URI.
48444845
sample_payload_url (str): The S3 path where the sample payload is stored.
48454846
supported_content_types (List[str]): The supported MIME types for the input data.
4847+
model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
48464848
model_package_version_arn (str): The Amazon Resource Name (ARN) of a
48474849
versioned model package.
48484850
job_duration_in_seconds (int): The maximum job duration that a job
@@ -4890,10 +4892,15 @@ def _create_inference_recommendations_job_request(
48904892
"RoleArn": role,
48914893
"InputConfig": {
48924894
"ContainerConfig": containerConfig,
4893-
"ModelPackageVersionArn": model_package_version_arn,
48944895
},
48954896
}
48964897

4898+
request.get("InputConfig").update(
4899+
{"ModelPackageVersionArn": model_package_version_arn}
4900+
if model_package_version_arn
4901+
else {"ModelName": model_name}
4902+
)
4903+
48974904
if job_description:
48984905
request["JobDescription"] = job_description
48994906
if job_duration_in_seconds:
@@ -4918,6 +4925,7 @@ def create_inference_recommendations_job(
49184925
supported_content_types: List[str],
49194926
job_name: str = None,
49204927
job_type: str = "Default",
4928+
model_name: str = None,
49214929
model_package_version_arn: str = None,
49224930
job_duration_in_seconds: int = None,
49234931
nearest_model_name: str = None,
@@ -4938,6 +4946,7 @@ def create_inference_recommendations_job(
49384946
You must grant sufficient permissions to this role.
49394947
sample_payload_url (str): The S3 path where the sample payload is stored.
49404948
supported_content_types (List[str]): The supported MIME types for the input data.
4949+
model_name (str): Name of the Amazon SageMaker ``Model`` to be used.
49414950
model_package_version_arn (str): The Amazon Resource Name (ARN) of a
49424951
versioned model package.
49434952
job_name (str): The name of the job being run.
@@ -4964,6 +4973,12 @@ def create_inference_recommendations_job(
49644973
str: The name of the job created. In the form of `SMPYTHONSDK-<timestamp>`
49654974
"""
49664975

4976+
if model_name is None and model_package_version_arn is None:
4977+
raise ValueError("Please provide either model_name or model_package_version_arn.")
4978+
4979+
if model_name is not None and model_package_version_arn is not None:
4980+
raise ValueError("Please provide either model_name or model_package_version_arn.")
4981+
49674982
if not job_name:
49684983
unique_tail = uuid.uuid4()
49694984
job_name = "SMPYTHONSDK-" + str(unique_tail)
@@ -4972,6 +4987,7 @@ def create_inference_recommendations_job(
49724987
create_inference_recommendations_job_request = (
49734988
self._create_inference_recommendations_job_request(
49744989
role=role,
4990+
model_name=model_name,
49754991
model_package_version_arn=model_package_version_arn,
49764992
job_name=job_name,
49774993
job_type=job_type,

tests/integ/test_inference_recommender.py

+181
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import pytest
1818

19+
from sagemaker.model import Model
1920
from sagemaker.sklearn.model import SKLearnModel, SKLearnPredictor
2021
from sagemaker.utils import unique_name_from_base
2122
from tests.integ import DATA_DIR
@@ -154,6 +155,120 @@ def advanced_right_sized_model(sagemaker_session, cpu_instance_type):
154155
)
155156

156157

158+
@pytest.fixture(scope="module")
159+
def default_right_sized_unregistered_model(sagemaker_session, cpu_instance_type):
160+
with timeout(minutes=45):
161+
try:
162+
ir_job_name = unique_name_from_base("test-ir-right-size-job-name")
163+
model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
164+
payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD)
165+
166+
iam_client = sagemaker_session.boto_session.client("iam")
167+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
168+
169+
sklearn_model = SKLearnModel(
170+
model_data=model_data,
171+
role=role_arn,
172+
entry_point=IR_SKLEARN_ENTRY_POINT,
173+
framework_version=IR_SKLEARN_FRAMEWORK_VERSION,
174+
)
175+
176+
return (
177+
sklearn_model.right_size(
178+
job_name=ir_job_name,
179+
sample_payload_url=payload_data,
180+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
181+
supported_instance_types=[cpu_instance_type],
182+
framework=IR_SKLEARN_FRAMEWORK,
183+
log_level="Quiet",
184+
),
185+
ir_job_name,
186+
)
187+
except Exception:
188+
sagemaker_session.delete_model(ModelName=sklearn_model.name)
189+
190+
191+
@pytest.fixture(scope="module")
192+
def advanced_right_sized_unregistered_model(sagemaker_session, cpu_instance_type):
193+
with timeout(minutes=45):
194+
try:
195+
model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
196+
payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD)
197+
198+
iam_client = sagemaker_session.boto_session.client("iam")
199+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
200+
201+
sklearn_model = SKLearnModel(
202+
model_data=model_data,
203+
role=role_arn,
204+
entry_point=IR_SKLEARN_ENTRY_POINT,
205+
framework_version=IR_SKLEARN_FRAMEWORK_VERSION,
206+
)
207+
208+
hyperparameter_ranges = [
209+
{
210+
"instance_types": CategoricalParameter([cpu_instance_type]),
211+
"TEST_PARAM": CategoricalParameter(
212+
["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"]
213+
),
214+
}
215+
]
216+
217+
phases = [
218+
Phase(duration_in_seconds=300, initial_number_of_users=2, spawn_rate=2),
219+
Phase(duration_in_seconds=300, initial_number_of_users=14, spawn_rate=2),
220+
]
221+
222+
model_latency_thresholds = [
223+
ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100)
224+
]
225+
226+
return sklearn_model.right_size(
227+
sample_payload_url=payload_data,
228+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
229+
framework=IR_SKLEARN_FRAMEWORK,
230+
job_duration_in_seconds=3600,
231+
hyperparameter_ranges=hyperparameter_ranges,
232+
phases=phases,
233+
model_latency_thresholds=model_latency_thresholds,
234+
max_invocations=100,
235+
max_tests=5,
236+
max_parallel_tests=5,
237+
log_level="Quiet",
238+
)
239+
240+
except Exception:
241+
sagemaker_session.delete_model(ModelName=sklearn_model.name)
242+
243+
244+
@pytest.fixture(scope="module")
245+
def default_right_sized_unregistered_base_model(sagemaker_session, cpu_instance_type):
246+
with timeout(minutes=45):
247+
try:
248+
ir_job_name = unique_name_from_base("test-ir-right-size-job-name")
249+
model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
250+
payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD)
251+
252+
iam_client = sagemaker_session.boto_session.client("iam")
253+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
254+
255+
model = Model(model_data=model_data, role=role_arn, entry_point=IR_SKLEARN_ENTRY_POINT)
256+
257+
return (
258+
model.right_size(
259+
job_name=ir_job_name,
260+
sample_payload_url=payload_data,
261+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
262+
supported_instance_types=[cpu_instance_type],
263+
framework=IR_SKLEARN_FRAMEWORK,
264+
log_level="Quiet",
265+
),
266+
ir_job_name,
267+
)
268+
except Exception:
269+
sagemaker_session.delete_model(ModelName=model.name)
270+
271+
157272
@pytest.mark.slow_test
158273
def test_default_right_size_and_deploy_registered_model_sklearn(
159274
default_right_sized_model, sagemaker_session
@@ -176,6 +291,72 @@ def test_default_right_size_and_deploy_registered_model_sklearn(
176291
predictor.delete_endpoint()
177292

178293

294+
@pytest.mark.slow_test
295+
def test_default_right_size_and_deploy_unregistered_model_sklearn(
296+
default_right_sized_unregistered_model, sagemaker_session
297+
):
298+
endpoint_name = unique_name_from_base("test-ir-right-size-default-unregistered-sklearn")
299+
300+
right_size_model, ir_job_name = default_right_sized_unregistered_model
301+
with timeout(minutes=45):
302+
try:
303+
right_size_model.predictor_cls = SKLearnPredictor
304+
predictor = right_size_model.deploy(endpoint_name=endpoint_name)
305+
306+
payload = pd.read_csv(IR_SKLEARN_DATA, header=None)
307+
308+
inference = predictor.predict(payload)
309+
assert inference is not None
310+
assert 26 == len(inference)
311+
finally:
312+
predictor.delete_model()
313+
predictor.delete_endpoint()
314+
315+
316+
@pytest.mark.slow_test
317+
def test_default_right_size_and_deploy_unregistered_base_model(
318+
default_right_sized_unregistered_base_model, sagemaker_session
319+
):
320+
endpoint_name = unique_name_from_base("test-ir-right-size-default-unregistered-base")
321+
322+
right_size_model, ir_job_name = default_right_sized_unregistered_base_model
323+
with timeout(minutes=45):
324+
try:
325+
right_size_model.predictor_cls = SKLearnPredictor
326+
predictor = right_size_model.deploy(endpoint_name=endpoint_name)
327+
328+
payload = pd.read_csv(IR_SKLEARN_DATA, header=None)
329+
330+
inference = predictor.predict(payload)
331+
assert inference is not None
332+
assert 26 == len(inference)
333+
finally:
334+
predictor.delete_model()
335+
predictor.delete_endpoint()
336+
337+
338+
@pytest.mark.slow_test
339+
def test_advanced_right_size_and_deploy_unregistered_model_sklearn(
340+
advanced_right_sized_unregistered_model, sagemaker_session
341+
):
342+
endpoint_name = unique_name_from_base("test-ir-right-size-advanced-sklearn")
343+
344+
right_size_model = advanced_right_sized_unregistered_model
345+
with timeout(minutes=45):
346+
try:
347+
right_size_model.predictor_cls = SKLearnPredictor
348+
predictor = right_size_model.deploy(endpoint_name=endpoint_name)
349+
350+
payload = pd.read_csv(IR_SKLEARN_DATA, header=None)
351+
352+
inference = predictor.predict(payload)
353+
assert inference is not None
354+
assert 26 == len(inference)
355+
finally:
356+
predictor.delete_model()
357+
predictor.delete_endpoint()
358+
359+
179360
@pytest.mark.slow_test
180361
def test_advanced_right_size_and_deploy_registered_model_sklearn(
181362
advanced_right_sized_model, sagemaker_session

0 commit comments

Comments
 (0)