Skip to content

Commit aefdcb7

Browse files
authored
feat: jumpstart model estimator classes (#3796)
1 parent 5a5b3d8 commit aefdcb7

File tree

79 files changed

+11392
-1881
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+11392
-1881
lines changed

src/sagemaker/accept_types.py

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module is for SageMaker accept types."""
14+
from __future__ import absolute_import
15+
from typing import List, Optional
16+
17+
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
18+
19+
20+
def retrieve_options(
21+
region: Optional[str] = None,
22+
model_id: Optional[str] = None,
23+
model_version: Optional[str] = None,
24+
tolerate_vulnerable_model: bool = False,
25+
tolerate_deprecated_model: bool = False,
26+
) -> List[str]:
27+
"""Retrieves the supported accept types for the model matching the given arguments.
28+
29+
Args:
30+
region (str): The AWS Region for which to retrieve the supported accept types.
31+
Defaults to ``None``.
32+
model_id (str): The model ID of the model for which to
33+
retrieve the supported accept types. (Default: None).
34+
model_version (str): The version of the model for which to retrieve the
35+
supported accept types. (Default: None).
36+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
37+
specifications should be tolerated (exception not raised). If False, raises an
38+
exception if the script used by this version of the model has dependencies with known
39+
security vulnerabilities. (Default: False).
40+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
41+
(exception not raised). False if these models should raise an exception.
42+
(Default: False).
43+
Returns:
44+
list: The supported accept types to use for the model.
45+
46+
Raises:
47+
ValueError: If the combination of arguments specified is not supported.
48+
"""
49+
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
50+
raise ValueError(
51+
"Must specify JumpStart `model_id` and `model_version` when retrieving accept types."
52+
)
53+
54+
return artifacts._retrieve_supported_accept_types(
55+
model_id,
56+
model_version,
57+
region,
58+
tolerate_vulnerable_model,
59+
tolerate_deprecated_model,
60+
)
61+
62+
63+
def retrieve_default(
64+
region: Optional[str] = None,
65+
model_id: Optional[str] = None,
66+
model_version: Optional[str] = None,
67+
tolerate_vulnerable_model: bool = False,
68+
tolerate_deprecated_model: bool = False,
69+
) -> str:
70+
"""Retrieves the default accept type for the model matching the given arguments.
71+
72+
Args:
73+
region (str): The AWS Region for which to retrieve the default accept type.
74+
Defaults to ``None``.
75+
model_id (str): The model ID of the model for which to
76+
retrieve the default accept type. (Default: None).
77+
model_version (str): The version of the model for which to retrieve the
78+
default accept type. (Default: None).
79+
tolerate_vulnerable_model (bool): True if vulnerable versions of model
80+
specifications should be tolerated (exception not raised). If False, raises an
81+
exception if the script used by this version of the model has dependencies with known
82+
security vulnerabilities. (Default: False).
83+
tolerate_deprecated_model (bool): True if deprecated models should be tolerated
84+
(exception not raised). False if these models should raise an exception.
85+
(Default: False).
86+
Returns:
87+
str: The default accept type to use for the model.
88+
89+
Raises:
90+
ValueError: If the combination of arguments specified is not supported.
91+
"""
92+
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
93+
raise ValueError(
94+
"Must specify JumpStart `model_id` and `model_version` when retrieving accept types."
95+
)
96+
97+
return artifacts._retrieve_default_accept_type(
98+
model_id,
99+
model_version,
100+
region,
101+
tolerate_vulnerable_model,
102+
tolerate_deprecated_model,
103+
)

0 commit comments

Comments
 (0)