Skip to content

Commit 9e5739a

Browse files
author
Raymond Liu
committed
address comments
1 parent 6b4e675 commit 9e5739a

File tree

5 files changed

+70
-31
lines changed

5 files changed

+70
-31
lines changed

src/sagemaker/inference_recommender/inference_recommender_mixin.py

Lines changed: 12 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ def right_size(
8080
):
8181
"""Recommends an instance type for a SageMaker or BYOC model.
8282
83+
Create a SageMaker ``Model`` or use a registered ``ModelPackage``
84+
to start an Inference Recommender job.
85+
86+
The name of the created model is accessible in the ``name`` field of
87+
this ``Model`` after ``right_size`` returns.
88+
8389
Args:
8490
sample_payload_url (str): The S3 path where the sample payload is stored.
8591
supported_content_types: (list[str]): The supported MIME types for the input data.
@@ -148,36 +154,29 @@ def right_size(
148154

149155
self._init_sagemaker_session_if_does_not_exist()
150156

151-
self.temp_model_name = None
152157
if isinstance(self, sagemaker.model.Model) and not isinstance(
153158
self, sagemaker.model.ModelPackage
154159
):
155-
156-
unique_tail = uuid.uuid4()
157-
self.temp_model_name = "SMPYTHONSDK-" + str(unique_tail)
158-
160+
if not self.name:
161+
unique_tail = uuid.uuid4()
162+
self.name = "SageMaker-Model-RightSized-" + str(unique_tail)
159163
create_model_args = dict(
160-
name=self.temp_model_name,
164+
name=self.name,
161165
role=self.role,
162166
container_defs=None,
163167
primary_container=self.prepare_container_def(),
164168
vpc_config=self.vpc_config,
165169
enable_network_isolation=self.enable_network_isolation(),
166170
)
167-
print(
168-
f"Creating temporary model with name: {self.temp_model_name}"
169-
" for Inference Recommender.",
170-
flush=True,
171-
)
171+
LOGGER.warning("Attempting to create new model with name %s", self.name)
172172
self.sagemaker_session.create_model(**create_model_args)
173-
print("Temporary model created. Start to run Inference Recommender...", flush=True)
174173

175174
ret_name = self.sagemaker_session.create_inference_recommendations_job(
176175
role=self.role,
177176
job_name=job_name,
178177
job_type=job_type,
179178
job_duration_in_seconds=job_duration_in_seconds,
180-
model_name=self.temp_model_name,
179+
model_name=self.name,
181180
model_package_version_arn=getattr(self, "model_package_arn", None),
182181
framework=framework,
183182
framework_version=framework_version,
@@ -199,15 +198,6 @@ def right_size(
199198
"InferenceRecommendations"
200199
)
201200

202-
if self.temp_model_name is not None:
203-
print(
204-
f"Deleting temporary model with name: {self.temp_model_name} "
205-
"for Inference Recommender.",
206-
flush=True,
207-
)
208-
self.sagemaker_session.delete_model(self.temp_model_name)
209-
self.temp_model_name = None
210-
print("Delete complete.")
211201
return self
212202

213203
def _update_params(

src/sagemaker/session.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4975,13 +4975,12 @@ def create_inference_recommendations_job(
49754975

49764976
if model_name is None and model_package_version_arn is None:
49774977
raise ValueError(
4978-
"Missing model_name and model_package_version_arn," " please provide one of them."
4978+
"Please provide either model_name or model_package_version_arn, not both."
49794979
)
49804980

49814981
if model_name is not None and model_package_version_arn is not None:
49824982
raise ValueError(
4983-
"Please provide either model_name or model_package_version_arn"
4984-
" should be provided, not both."
4983+
"Please provide either model_name or model_package_version_arn, not both."
49854984
)
49864985

49874986
if not job_name:

tests/integ/test_inference_recommender.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import pytest
1818

19+
from sagemaker.model import Model
1920
from sagemaker.sklearn.model import SKLearnModel, SKLearnPredictor
2021
from sagemaker.utils import unique_name_from_base
2122
from tests.integ import DATA_DIR
@@ -184,7 +185,7 @@ def default_right_sized_unregistered_model(sagemaker_session, cpu_instance_type)
184185
ir_job_name,
185186
)
186187
except Exception:
187-
sagemaker_session.delete_model(ModelName=sklearn_model.temp_model_name)
188+
sagemaker_session.delete_model(ModelName=sklearn_model.name)
188189

189190

190191
@pytest.fixture(scope="module")
@@ -237,7 +238,35 @@ def advanced_right_sized_unregistered_model(sagemaker_session, cpu_instance_type
237238
)
238239

239240
except Exception:
240-
sagemaker_session.delete_model(ModelName=sklearn_model.temp_model_name)
241+
sagemaker_session.delete_model(ModelName=sklearn_model.name)
242+
243+
244+
@pytest.fixture(scope="module")
245+
def default_right_sized_unregistered_base_model(sagemaker_session, cpu_instance_type):
246+
with timeout(minutes=45):
247+
try:
248+
ir_job_name = unique_name_from_base("test-ir-right-size-job-name")
249+
model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
250+
payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD)
251+
252+
iam_client = sagemaker_session.boto_session.client("iam")
253+
role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
254+
255+
model = Model(model_data=model_data, role=role_arn, entry_point=IR_SKLEARN_ENTRY_POINT)
256+
257+
return (
258+
model.right_size(
259+
job_name=ir_job_name,
260+
sample_payload_url=payload_data,
261+
supported_content_types=IR_SKLEARN_CONTENT_TYPE,
262+
supported_instance_types=[cpu_instance_type],
263+
framework=IR_SKLEARN_FRAMEWORK,
264+
log_level="Quiet",
265+
),
266+
ir_job_name,
267+
)
268+
except Exception:
269+
sagemaker_session.delete_model(ModelName=model.name)
241270

242271

243272
@pytest.mark.slow_test
@@ -284,6 +313,28 @@ def test_default_right_size_and_deploy_unregistered_model_sklearn(
284313
predictor.delete_endpoint()
285314

286315

316+
@pytest.mark.slow_test
317+
def test_default_right_size_and_deploy_unregistered_base_model(
318+
default_right_sized_unregistered_base_model, sagemaker_session
319+
):
320+
endpoint_name = unique_name_from_base("test-ir-right-size-default-unregistered-base")
321+
322+
right_size_model, ir_job_name = default_right_sized_unregistered_base_model
323+
with timeout(minutes=45):
324+
try:
325+
right_size_model.predictor_cls = SKLearnPredictor
326+
predictor = right_size_model.deploy(endpoint_name=endpoint_name)
327+
328+
payload = pd.read_csv(IR_SKLEARN_DATA, header=None)
329+
330+
inference = predictor.predict(payload)
331+
assert inference is not None
332+
assert 26 == len(inference)
333+
finally:
334+
predictor.delete_model()
335+
predictor.delete_endpoint()
336+
337+
287338
@pytest.mark.slow_test
288339
def test_advanced_right_size_and_deploy_unregistered_model_sklearn(
289340
advanced_right_sized_unregistered_model, sagemaker_session

tests/unit/sagemaker/inference_recommender/test_inference_recommender_mixin.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
IR_SUPPORTED_CONTENT_TYPES = ["text/csv"]
2727
IR_JOB_NAME = "SMPYTHONSDK-1234567891"
2828
IR_SAMPLE_INSTANCE_TYPE = "ml.c5.xlarge"
29-
IR_MODEL_NAME = "SMPYTHONSDK-sample-unique-uuid"
29+
IR_MODEL_NAME = "SageMaker-Model-RightSized-sample-unique-uuid"
3030

3131
IR_SAMPLE_LIST_OF_INSTANCES_HYPERPARAMETER_RANGES = [
3232
{
@@ -186,7 +186,6 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
186186
framework=IR_SAMPLE_FRAMEWORK,
187187
)
188188

189-
# assert that the create model api has been called with default parameters
190189
assert sagemaker_session.create_model.called_with(
191190
name=IR_MODEL_NAME,
192191
role=IR_ROLE_ARN,

tests/unit/test_session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3396,7 +3396,7 @@ def test_create_inference_recommendations_job_advanced_model_name_happy(sagemake
33963396
def test_create_inference_recommendations_job_missing_model_name_and_pkg(sagemaker_session):
33973397
with pytest.raises(
33983398
ValueError,
3399-
match="Missing model_name and model_package_version_arn, please provide one of them.",
3399+
match="Please provide either model_name or model_package_version_arn, not both.",
34003400
):
34013401
sagemaker_session.create_inference_recommendations_job(
34023402
role=IR_ROLE_ARN,
@@ -3415,7 +3415,7 @@ def test_create_inference_recommendations_job_missing_model_name_and_pkg(sagemak
34153415
def test_create_inference_recommendations_job_provided_model_name_and_pkg(sagemaker_session):
34163416
with pytest.raises(
34173417
ValueError,
3418-
match="Please provide either model_name or model_package_version_arn should be provided, not both.",
3418+
match="Please provide either model_name or model_package_version_arn, not both.",
34193419
):
34203420
sagemaker_session.create_inference_recommendations_job(
34213421
role=IR_ROLE_ARN,

0 commit comments

Comments
 (0)