-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: jumpstart instance types #3686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
c28134a
37ad9e4
7baec98
1da64e7
975def8
54e6754
af7639a
980ea50
0375b61
a1a5d8c
ca767f4
044fcb2
d3424ab
9bc9629
db599df
40d8680
4058e22
f9e2fb4
9a46fcb
d3c77aa
516da5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
"""Accessors to retrieve instance types.""" | ||
|
||
from __future__ import absolute_import | ||
|
||
import logging | ||
from typing import List, Optional | ||
|
||
from sagemaker.jumpstart import utils as jumpstart_utils | ||
from sagemaker.jumpstart import artifacts | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def retrieve_default( | ||
region=None, | ||
model_id=None, | ||
model_version=None, | ||
scope=None, | ||
tolerate_vulnerable_model: bool = False, | ||
tolerate_deprecated_model: bool = False, | ||
) -> Optional[str]: | ||
"""Retrieves the default instance type for the model matching the given arguments. | ||
|
||
Args: | ||
region (str): The AWS Region for which to retrieve the default instance type. | ||
Defaults to ``None``. | ||
model_id (str): The model ID of the model for which to | ||
retrieve the default instance type. (Default: None). | ||
model_version (str): The version of the model for which to retrieve the | ||
default instance type. (Default: None). | ||
scope (str): The model type, i.e. what it is used for. | ||
Valid values: "training" and "inference". | ||
tolerate_vulnerable_model (bool): True if vulnerable versions of model | ||
specifications should be tolerated (exception not raised). If False, raises an | ||
exception if the script used by this version of the model has dependencies with known | ||
security vulnerabilities. (Default: False). | ||
tolerate_deprecated_model (bool): True if deprecated models should be tolerated | ||
(exception not raised). False if these models should raise an exception. | ||
(Default: False). | ||
Returns: | ||
dict: The default instance type to use for the model. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why |
||
|
||
Raises: | ||
ValueError: If the combination of arguments specified is not supported. | ||
""" | ||
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): | ||
raise ValueError( | ||
"Must specify `model_id` and `model_version` when retrieving instance types." | ||
) | ||
|
||
if scope is None: | ||
raise ValueError("Must specify scope for instance types.") | ||
|
||
return artifacts._retrieve_default_instance_type( | ||
model_id, | ||
model_version, | ||
scope, | ||
region, | ||
tolerate_vulnerable_model, | ||
tolerate_deprecated_model, | ||
) | ||
|
||
|
||
def retrieve_supported( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. question: how about I know that it could look like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can but that would break the pattern we have with the other JS utilities. I don't think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, could we compromise on simply: |
||
region=None, | ||
model_id=None, | ||
model_version=None, | ||
scope=None, | ||
tolerate_vulnerable_model: bool = False, | ||
tolerate_deprecated_model: bool = False, | ||
) -> Optional[List[str]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same question please. If no instance is available in a region, I would vote for either returning an empty list or raising an error. |
||
"""Retrieves the supported training instance types for the model matching the given arguments. | ||
|
||
Args: | ||
region (str): The AWS Region for which to retrieve the supported instance types. | ||
Defaults to ``None``. | ||
model_id (str): The model ID of the model for which to | ||
retrieve the supported instance types. (Default: None). | ||
model_version (str): The version of the model for which to retrieve the | ||
supported instance types. (Default: None). | ||
tolerate_vulnerable_model (bool): True if vulnerable versions of model | ||
specifications should be tolerated (exception not raised). If False, raises an | ||
exception if the script used by this version of the model has dependencies with known | ||
security vulnerabilities. (Default: False). | ||
tolerate_deprecated_model (bool): True if deprecated models should be tolerated | ||
(exception not raised). False if these models should raise an exception. | ||
(Default: False). | ||
Returns: | ||
dict: The supported instance types to use for the model. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. revise docstring please, this utility returns a list. |
||
|
||
Raises: | ||
ValueError: If the combination of arguments specified is not supported. | ||
""" | ||
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version): | ||
raise ValueError( | ||
"Must specify `model_id` and `model_version` when retrieving instance types." | ||
) | ||
|
||
if scope is None: | ||
raise ValueError("Must specify scope for instance types.") | ||
|
||
return artifacts._retrieve_supported_instance_type( | ||
model_id, | ||
model_version, | ||
scope, | ||
region, | ||
tolerate_vulnerable_model, | ||
tolerate_deprecated_model, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
question: why is the return type
Optional
?If a model is not supported in a region because no instances are available, I would vote to raise an exception with an error message asking customers to use a different AWS region.