-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Decouple model.right_size() from model registry #3688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
506806e
265015a
e1e9ebd
f2ed8f4
0f35930
3ed3d7a
160d5c8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
|
||
import logging | ||
import re | ||
import uuid | ||
|
||
from typing import List, Dict, Optional | ||
import sagemaker | ||
|
@@ -38,7 +39,7 @@ class Phase: | |
""" | ||
|
||
def __init__(self, duration_in_seconds: int, initial_number_of_users: int, spawn_rate: int): | ||
"""Initialze a `Phase`""" | ||
"""Initialize a `Phase`""" | ||
self.to_json = { | ||
"DurationInSeconds": duration_in_seconds, | ||
"InitialNumberOfUsers": initial_number_of_users, | ||
|
@@ -53,7 +54,7 @@ class ModelLatencyThreshold: | |
""" | ||
|
||
def __init__(self, percentile: str, value_in_milliseconds: int): | ||
"""Initialze a `ModelLatencyThreshold`""" | ||
"""Initialize a `ModelLatencyThreshold`""" | ||
self.to_json = {"Percentile": percentile, "ValueInMilliseconds": value_in_milliseconds} | ||
|
||
|
||
|
@@ -79,6 +80,12 @@ def right_size( | |
): | ||
"""Recommends an instance type for a SageMaker or BYOC model. | ||
|
||
Create a SageMaker ``Model`` or use a registered ``ModelPackage``, | ||
to start an Inference Recommender job. | ||
|
||
The name of the created model is accessible in the ``name`` field of | ||
this ``Model`` after right_size returns. | ||
|
||
Args: | ||
sample_payload_url (str): The S3 path where the sample payload is stored. | ||
supported_content_types: (list[str]): The supported MIME types for the input data. | ||
|
@@ -119,8 +126,6 @@ def right_size( | |
sagemaker.model.Model: A SageMaker ``Model`` object. See | ||
:func:`~sagemaker.model.Model` for full details. | ||
""" | ||
if not isinstance(self, sagemaker.model.ModelPackage): | ||
raise ValueError("right_size() is currently only supported with a registered model") | ||
|
||
if not framework and self._framework(): | ||
framework = INFERENCE_RECOMMENDER_FRAMEWORK_MAPPING.get(self._framework(), framework) | ||
|
@@ -149,12 +154,30 @@ def right_size( | |
|
||
self._init_sagemaker_session_if_does_not_exist() | ||
|
||
if isinstance(self, sagemaker.model.Model) and not isinstance( | ||
self, sagemaker.model.ModelPackage | ||
): | ||
if not self.name: | ||
unique_tail = uuid.uuid4() | ||
self.name = "SageMaker-Model-RightSized-" + str(unique_tail) | ||
SSRraymond marked this conversation as resolved.
Show resolved
Hide resolved
|
||
create_model_args = dict( | ||
name=self.name, | ||
role=self.role, | ||
container_defs=None, | ||
primary_container=self.prepare_container_def(), | ||
vpc_config=self.vpc_config, | ||
enable_network_isolation=self.enable_network_isolation(), | ||
) | ||
LOGGER.warning("Attempting to create new model with name %s", self.name) | ||
self.sagemaker_session.create_model(**create_model_args) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could we not just replace this whole block with this?
Do we need to create a temporary model name? Can we not just use the original name? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
ret_name = self.sagemaker_session.create_inference_recommendations_job( | ||
role=self.role, | ||
job_name=job_name, | ||
job_type=job_type, | ||
job_duration_in_seconds=job_duration_in_seconds, | ||
model_package_version_arn=self.model_package_arn, | ||
model_name=self.name, | ||
model_package_version_arn=getattr(self, "model_package_arn", None), | ||
framework=framework, | ||
framework_version=framework_version, | ||
sample_payload_url=sample_payload_url, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
|
||
import pytest | ||
|
||
from sagemaker.model import Model | ||
from sagemaker.sklearn.model import SKLearnModel, SKLearnPredictor | ||
from sagemaker.utils import unique_name_from_base | ||
from tests.integ import DATA_DIR | ||
|
@@ -154,6 +155,120 @@ def advanced_right_sized_model(sagemaker_session, cpu_instance_type): | |
) | ||
|
||
|
||
@pytest.fixture(scope="module")
def default_right_sized_unregistered_model(sagemaker_session, cpu_instance_type):
    """Right-size an unregistered ``SKLearnModel`` with the default job settings.

    Uploads the sklearn model artifact and the sample payload, builds an
    ``SKLearnModel`` and calls ``right_size`` on it with a single supported
    instance type.

    Returns:
        tuple: ``(right_sized_model, ir_job_name)``.

    Raises:
        Exception: Any failure is re-raised after attempting to delete the
            model, so that a broken setup is reported as such instead of the
            fixture silently yielding ``None``.
    """
    # Bound up front so the cleanup path never hits an unbound local.
    sklearn_model = None
    with timeout(minutes=45):
        try:
            ir_job_name = unique_name_from_base("test-ir-right-size-job-name")
            model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
            payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD)

            iam_client = sagemaker_session.boto_session.client("iam")
            role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

            sklearn_model = SKLearnModel(
                model_data=model_data,
                role=role_arn,
                entry_point=IR_SKLEARN_ENTRY_POINT,
                framework_version=IR_SKLEARN_FRAMEWORK_VERSION,
            )

            return (
                sklearn_model.right_size(
                    job_name=ir_job_name,
                    sample_payload_url=payload_data,
                    supported_content_types=IR_SKLEARN_CONTENT_TYPE,
                    supported_instance_types=[cpu_instance_type],
                    framework=IR_SKLEARN_FRAMEWORK,
                    log_level="Quiet",
                ),
                ir_job_name,
            )
        except Exception:
            # Only attempt cleanup when a model object with a name exists.
            # NOTE(review): confirm that ``delete_model`` on the session accepts
            # a ``ModelName`` keyword -- TODO verify against the SDK.
            if sklearn_model is not None and sklearn_model.name:
                sagemaker_session.delete_model(ModelName=sklearn_model.name)
            raise
|
||
|
||
@pytest.fixture(scope="module")
def advanced_right_sized_unregistered_model(sagemaker_session, cpu_instance_type):
    """Right-size an unregistered ``SKLearnModel`` with advanced job settings.

    Uploads the sklearn model artifact and the sample payload, builds an
    ``SKLearnModel`` and calls ``right_size`` with hyperparameter ranges,
    traffic phases, latency thresholds and stopping conditions.

    Returns:
        The value returned by ``SKLearnModel.right_size``.

    Raises:
        Exception: Any failure is re-raised after attempting to delete the
            model, so that a broken setup is reported as such instead of the
            fixture silently yielding ``None``.
    """
    # Bound up front so the cleanup path never hits an unbound local.
    sklearn_model = None
    with timeout(minutes=45):
        try:
            model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
            payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD)

            iam_client = sagemaker_session.boto_session.client("iam")
            role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

            sklearn_model = SKLearnModel(
                model_data=model_data,
                role=role_arn,
                entry_point=IR_SKLEARN_ENTRY_POINT,
                framework_version=IR_SKLEARN_FRAMEWORK_VERSION,
            )

            hyperparameter_ranges = [
                {
                    "instance_types": CategoricalParameter([cpu_instance_type]),
                    "TEST_PARAM": CategoricalParameter(
                        ["TEST_PARAM_VALUE_1", "TEST_PARAM_VALUE_2"]
                    ),
                }
            ]

            phases = [
                Phase(duration_in_seconds=300, initial_number_of_users=2, spawn_rate=2),
                Phase(duration_in_seconds=300, initial_number_of_users=14, spawn_rate=2),
            ]

            model_latency_thresholds = [
                ModelLatencyThreshold(percentile="P95", value_in_milliseconds=100)
            ]

            return sklearn_model.right_size(
                sample_payload_url=payload_data,
                supported_content_types=IR_SKLEARN_CONTENT_TYPE,
                framework=IR_SKLEARN_FRAMEWORK,
                job_duration_in_seconds=3600,
                hyperparameter_ranges=hyperparameter_ranges,
                phases=phases,
                model_latency_thresholds=model_latency_thresholds,
                max_invocations=100,
                max_tests=5,
                max_parallel_tests=5,
                log_level="Quiet",
            )

        except Exception:
            # Only attempt cleanup when a model object with a name exists.
            # NOTE(review): confirm that ``delete_model`` on the session accepts
            # a ``ModelName`` keyword -- TODO verify against the SDK.
            if sklearn_model is not None and sklearn_model.name:
                sagemaker_session.delete_model(ModelName=sklearn_model.name)
            raise
|
||
|
||
@pytest.fixture(scope="module")
def default_right_sized_unregistered_base_model(sagemaker_session, cpu_instance_type):
    """Right-size an unregistered generic ``Model`` with the default job settings.

    Same flow as the sklearn fixture, but built on the base ``Model`` class
    rather than a framework-specific subclass.

    Returns:
        tuple: ``(right_sized_model, ir_job_name)``.

    Raises:
        Exception: Any failure is re-raised after attempting to delete the
            model, so that a broken setup is reported as such instead of the
            fixture silently yielding ``None``.
    """
    # Bound up front so the cleanup path never hits an unbound local.
    model = None
    with timeout(minutes=45):
        try:
            ir_job_name = unique_name_from_base("test-ir-right-size-job-name")
            model_data = sagemaker_session.upload_data(path=IR_SKLEARN_MODEL)
            payload_data = sagemaker_session.upload_data(path=IR_SKLEARN_PAYLOAD)

            iam_client = sagemaker_session.boto_session.client("iam")
            role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]

            model = Model(model_data=model_data, role=role_arn, entry_point=IR_SKLEARN_ENTRY_POINT)

            return (
                model.right_size(
                    job_name=ir_job_name,
                    sample_payload_url=payload_data,
                    supported_content_types=IR_SKLEARN_CONTENT_TYPE,
                    supported_instance_types=[cpu_instance_type],
                    framework=IR_SKLEARN_FRAMEWORK,
                    log_level="Quiet",
                ),
                ir_job_name,
            )
        except Exception:
            # Only attempt cleanup when a model object with a name exists.
            # NOTE(review): confirm that ``delete_model`` on the session accepts
            # a ``ModelName`` keyword -- TODO verify against the SDK.
            if model is not None and model.name:
                sagemaker_session.delete_model(ModelName=model.name)
            raise
|
||
|
||
@pytest.mark.slow_test | ||
def test_default_right_size_and_deploy_registered_model_sklearn( | ||
default_right_sized_model, sagemaker_session | ||
|
@@ -176,6 +291,72 @@ def test_default_right_size_and_deploy_registered_model_sklearn( | |
predictor.delete_endpoint() | ||
|
||
|
||
@pytest.mark.slow_test
def test_default_right_size_and_deploy_unregistered_model_sklearn(
    default_right_sized_unregistered_model, sagemaker_session
):
    """Deploy the default right-sized unregistered sklearn model and run inference."""
    endpoint_name = unique_name_from_base("test-ir-right-size-default-unregistered-sklearn")

    # The job name from the fixture is not needed by this test.
    right_size_model, _ = default_right_sized_unregistered_model
    # Bound before the ``try`` so ``finally`` cannot raise NameError and hide
    # the real failure when ``deploy`` itself raises.
    predictor = None
    with timeout(minutes=45):
        try:
            right_size_model.predictor_cls = SKLearnPredictor
            predictor = right_size_model.deploy(endpoint_name=endpoint_name)

            payload = pd.read_csv(IR_SKLEARN_DATA, header=None)

            inference = predictor.predict(payload)
            assert inference is not None
            assert len(inference) == 26
        finally:
            if predictor is not None:
                predictor.delete_model()
                predictor.delete_endpoint()
|
||
|
||
@pytest.mark.slow_test
def test_default_right_size_and_deploy_unregistered_base_model(
    default_right_sized_unregistered_base_model, sagemaker_session
):
    """Deploy the default right-sized unregistered base ``Model`` and run inference."""
    endpoint_name = unique_name_from_base("test-ir-right-size-default-unregistered-base")

    # The job name from the fixture is not needed by this test.
    right_size_model, _ = default_right_sized_unregistered_base_model
    # Bound before the ``try`` so ``finally`` cannot raise NameError and hide
    # the real failure when ``deploy`` itself raises.
    predictor = None
    with timeout(minutes=45):
        try:
            right_size_model.predictor_cls = SKLearnPredictor
            predictor = right_size_model.deploy(endpoint_name=endpoint_name)

            payload = pd.read_csv(IR_SKLEARN_DATA, header=None)

            inference = predictor.predict(payload)
            assert inference is not None
            assert len(inference) == 26
        finally:
            if predictor is not None:
                predictor.delete_model()
                predictor.delete_endpoint()
|
||
|
||
@pytest.mark.slow_test
def test_advanced_right_size_and_deploy_unregistered_model_sklearn(
    advanced_right_sized_unregistered_model, sagemaker_session
):
    """Deploy the advanced right-sized unregistered sklearn model and run inference."""
    endpoint_name = unique_name_from_base("test-ir-right-size-advanced-sklearn")

    right_size_model = advanced_right_sized_unregistered_model
    # Bound before the ``try`` so ``finally`` cannot raise NameError and hide
    # the real failure when ``deploy`` itself raises.
    predictor = None
    with timeout(minutes=45):
        try:
            right_size_model.predictor_cls = SKLearnPredictor
            predictor = right_size_model.deploy(endpoint_name=endpoint_name)

            payload = pd.read_csv(IR_SKLEARN_DATA, header=None)

            inference = predictor.predict(payload)
            assert inference is not None
            assert len(inference) == 26
        finally:
            if predictor is not None:
                predictor.delete_model()
                predictor.delete_endpoint()
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what about scenario:
? Do we need to cover this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to be very clear, I mean the generic Model case There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think covering the SKLearn model case is good enough because SKLearn model inherits the model class There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hmm but a |
||
@pytest.mark.slow_test | ||
def test_advanced_right_size_and_deploy_registered_model_sklearn( | ||
advanced_right_sized_model, sagemaker_session | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we check and throw an exception when the object is
not an instance of Model
and not an instance of ModelPackage?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need to, because Model is the base class for everything.