Skip to content

Commit c850dc5

Browse files
committed
change: cleanup code
1 parent 26c5169 commit c850dc5

File tree

11 files changed

+145
-76
lines changed

11 files changed

+145
-76
lines changed

src/sagemaker/image_uris.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ def retrieve(
8282
model_version (str): Version of the JumpStart model for which to retrieve the
8383
image URI (default: None).
8484
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
85-
not thrown). False if these models should throw an exception. (Default: None).
85+
not raised). False if these models should raise an exception. (Default: None).
8686
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
87-
not thrown). False if these models should throw an exception. (Default: None).
87+
not raised). False if these models should raise an exception. (Default: None).
8888
8989
Returns:
9090
str: the ECR URI for the corresponding SageMaker Docker image.

src/sagemaker/jumpstart/artifacts.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
from sagemaker import image_uris
1717
from sagemaker.jumpstart.constants import (
1818
JUMPSTART_DEFAULT_REGION_NAME,
19-
INFERENCE,
20-
TRAINING,
19+
JumpStartScriptScope,
2120
ModelFramework,
2221
VariableScope,
2322
)
@@ -77,9 +76,9 @@ def _retrieve_image_uri(
7776
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7877
A configuration class for the SageMaker Training Compiler.
7978
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
80-
not thrown). False if these models should throw an exception.
79+
not raised). False if these models should raise an exception.
8180
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
82-
not thrown). False if these models should throw an exception.
81+
not raised). False if these models should raise an exception.
8382
8483
Returns:
8584
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -103,9 +102,9 @@ def _retrieve_image_uri(
103102
tolerate_deprecated_model=tolerate_deprecated_model,
104103
)
105104

106-
if image_scope == INFERENCE:
105+
if image_scope == JumpStartScriptScope.INFERENCE.value:
107106
ecr_specs = model_specs.hosting_ecr_specs
108-
elif image_scope == TRAINING:
107+
elif image_scope == JumpStartScriptScope.TRAINING.value:
109108
assert model_specs.training_ecr_specs is not None
110109
ecr_specs = model_specs.training_ecr_specs
111110

@@ -133,7 +132,7 @@ def _retrieve_image_uri(
133132
base_framework_version_override = ecr_specs.framework_version
134133
version_override = ecr_specs.huggingface_transformers_version
135134

136-
if image_scope == TRAINING:
135+
if image_scope == JumpStartScriptScope.TRAINING.value:
137136
return image_uris.get_training_image_uri(
138137
region=region,
139138
framework=ecr_specs.framework,
@@ -183,9 +182,9 @@ def _retrieve_model_uri(
183182
Valid values: "training" and "inference".
184183
region (str): Region for which to retrieve model S3 URI.
185184
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
186-
not thrown). False if these models should throw an exception.
185+
not raised). False if these models should raise an exception.
187186
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
188-
not thrown). False if these models should throw an exception.
187+
not raised). False if these models should raise an exception.
189188
Returns:
190189
str: the model artifact S3 URI for the corresponding model.
191190
@@ -208,9 +207,9 @@ def _retrieve_model_uri(
208207
tolerate_deprecated_model=tolerate_deprecated_model,
209208
)
210209

211-
if model_scope == INFERENCE:
210+
if model_scope == JumpStartScriptScope.INFERENCE.value:
212211
model_artifact_key = model_specs.hosting_artifact_key
213-
elif model_scope == TRAINING:
212+
elif model_scope == JumpStartScriptScope.TRAINING.value:
214213
assert model_specs.training_artifact_key is not None
215214
model_artifact_key = model_specs.training_artifact_key
216215

@@ -240,9 +239,9 @@ def _retrieve_script_uri(
240239
Valid values: "training" and "inference".
241240
region (str): Region for which to retrieve model script S3 URI.
242241
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
243-
not thrown). False if these models should throw an exception.
242+
not raised). False if these models should raise an exception.
244243
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
245-
not thrown). False if these models should throw an exception.
244+
not raised). False if these models should raise an exception.
246245
Returns:
247246
str: the model script URI for the corresponding model.
248247
@@ -265,9 +264,9 @@ def _retrieve_script_uri(
265264
tolerate_deprecated_model=tolerate_deprecated_model,
266265
)
267266

268-
if script_scope == INFERENCE:
267+
if script_scope == JumpStartScriptScope.INFERENCE.value:
269268
model_script_key = model_specs.hosting_script_key
270-
elif script_scope == TRAINING:
269+
elif script_scope == JumpStartScriptScope.TRAINING.value:
271270
assert model_specs.training_script_key is not None
272271
model_script_key = model_specs.training_script_key
273272

src/sagemaker/jumpstart/constants.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,21 @@
116116

117117
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
118118

119-
INFERENCE = "inference"
120-
TRAINING = "training"
121-
SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING])
122119

123120
INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py"
124121
TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py"
125122

126123

124+
class JumpStartScriptScope(str, Enum):
125+
"""Enum class for JumpStart script scopes."""
126+
127+
INFERENCE = "inference"
128+
TRAINING = "training"
129+
130+
131+
SUPPORTED_JUMPSTART_SCOPES = set(scope.value for scope in JumpStartScriptScope)
132+
133+
127134
class ModelFramework(str, Enum):
128135
"""Enum class for JumpStart model framework.
129136

src/sagemaker/jumpstart/exceptions.py

Lines changed: 51 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,46 +12,79 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores exceptions related to SageMaker JumpStart."""
1414

15+
from __future__ import absolute_import
1516
from typing import List, Optional
1617

18+
from sagemaker.jumpstart.constants import JumpStartScriptScope
19+
1720

1821
class VulnerableJumpStartModelError(Exception):
19-
"""Exception raised for errors with vulnerable JumpStart models."""
22+
"""Exception raised when trying to access a JumpStart model specs flagged as vulnerable.
23+
24+
Raise this exception only if the scope of attributes accessed in the specifications have
25+
vulnerabilities. For example, a model training script may have vulnerabilities, but not
26+
the hosting scripts. In such a case, raise a ``VulnerableJumpStartModelError`` only when
27+
accessing the training specifications.
28+
"""
2029

2130
def __init__(
2231
self,
2332
model_id: Optional[str] = None,
2433
version: Optional[str] = None,
2534
vulnerabilities: Optional[List[str]] = None,
26-
inference: Optional[bool] = None,
35+
scope: Optional[JumpStartScriptScope] = None,
2736
message: Optional[str] = None,
2837
):
38+
"""Instantiates VulnerableJumpStartModelError exception.
39+
40+
Args:
41+
model_id (Optional[str]): model id of vulnerable JumpStart model.
42+
(Default: None).
43+
version (Optional[str]): version of vulnerable JumpStart model.
44+
(Default: None).
45+
vulnerabilities (Optional[List[str]]): vulnerabilities associated with
46+
model. (Default: None).
47+
48+
"""
2949
if message:
3050
self.message = message
3151
else:
32-
if None in [model_id, version, vulnerabilities, inference]:
52+
if None in [model_id, version, vulnerabilities, scope]:
3353
raise ValueError(
34-
"Must specify `model_id`, `version`, `vulnerabilities`, "
35-
"and inference arguments."
54+
"Must specify `model_id`, `version`, `vulnerabilities`, " "and scope arguments."
3655
)
37-
if inference is True:
56+
if scope == JumpStartScriptScope.INFERENCE:
3857
self.message = (
39-
f"JumpStart model '{model_id}' and version '{version}' has at least 1 "
40-
"vulnerable dependency in the inference scripts. "
41-
f"List of vulnerabilities: {', '.join(vulnerabilities)}"
58+
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
59+
"has at least 1 vulnerable dependency in the inference script. "
60+
"Please try targetting a higher version of the model. "
61+
f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore
4262
)
43-
else:
63+
elif scope == JumpStartScriptScope.TRAINING:
4464
self.message = (
45-
f"JumpStart model '{model_id}' and version '{version}' has at least 1 "
46-
"vulnerable dependency in the training scripts. "
47-
f"List of vulnerabilities: {', '.join(vulnerabilities)}"
65+
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
66+
"has at least 1 vulnerable dependency in the training script. "
67+
"Please try targetting a higher version of the model. "
68+
f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore
69+
)
70+
else:
71+
raise NotImplementedError(
72+
"Unsupported scope for VulnerableJumpStartModelError: " # type: ignore
73+
f"'{scope.value}'"
4874
)
4975

5076
super().__init__(self.message)
5177

5278

5379
class DeprecatedJumpStartModelError(Exception):
54-
"""Exception raised for errors with deprecated JumpStart models."""
80+
"""Exception raised when trying to access a JumpStart model deprecated specifications.
81+
82+
A deprecated specification for a JumpStart model does not mean the whole model is
83+
deprecated. There may be more recent specifications available for this model. For
84+
example, all specification before version ``2.0.0`` may be deprecated, in such a
85+
case, the SDK would raise this exception only when specifications ``1.*`` are
86+
accessed.
87+
"""
5588

5689
def __init__(
5790
self,
@@ -64,6 +97,9 @@ def __init__(
6497
else:
6598
if None in [model_id, version]:
6699
raise ValueError("Must specify `model_id` and `version` arguments.")
67-
self.message = f"JumpStart model '{model_id}' and version '{version}' is deprecated."
100+
self.message = (
101+
f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
102+
"Please try targetting a higher version of the model."
103+
)
68104

69105
super().__init__(self.message)

src/sagemaker/jumpstart/utils.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
DeprecatedJumpStartModelError,
2222
VulnerableJumpStartModelError,
2323
)
24-
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartVersionedModelId
24+
from sagemaker.jumpstart.types import (
25+
JumpStartModelHeader,
26+
JumpStartModelSpecs,
27+
JumpStartVersionedModelId,
28+
)
2529

2630

2731
def get_jumpstart_launched_regions_message() -> str:
@@ -149,12 +153,9 @@ def verify_model_region_and_return_specs(
149153
region: str,
150154
tolerate_vulnerable_model: Optional[bool] = None,
151155
tolerate_deprecated_model: Optional[bool] = None,
152-
):
156+
) -> JumpStartModelSpecs:
153157
"""Verifies that an acceptable model_id, version, scope, and region combination is provided.
154158
155-
If the scope is not supported, the model id/region/version has no spec, or the model is vulnerable
156-
or deprecated, an exception will be raised.
157-
158159
Args:
159160
model_id (Optional[str]): model id of the JumpStart model to verify and
160161
obtains specs.
@@ -163,10 +164,19 @@ def verify_model_region_and_return_specs(
163164
scope (Optional[str]): scope of the JumpStart model to verify.
164165
region (Optional[str]): region of the JumpStart model to verify and
165166
obtains specs.
166-
tolerate_vulnerable_model (Optional[bool]): True if vulnerable models should be tolerated (exception
167-
not thrown). False if these models should throw an exception. (Default: None).
168-
tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated (exception
169-
not thrown). False if these models should throw an exception. (Default: None).
167+
tolerate_vulnerable_model (Optional[bool]): True if vulnerable models should be tolerated
168+
(exception not raised). False if these models should raise an exception.
169+
(Default: None).
170+
tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated
171+
(exception not raised). False if these models should raise an exception.
172+
(Default: None).
173+
174+
175+
Raises:
176+
ValueError: If the combination of arguments specified is not supported.
177+
NotImplementedError: If the scope is not supported.
178+
VulnerableJumpStartModelError: If the model is vulnerable.
179+
DeprecatedJumpStartModelError: If the model is deprecated.
170180
"""
171181

172182
if tolerate_vulnerable_model is None:
@@ -182,15 +192,22 @@ def verify_model_region_and_return_specs(
182192
)
183193

184194
if scope not in constants.SUPPORTED_JUMPSTART_SCOPES:
185-
raise ValueError(
186-
f"JumpStart models only support scopes: {', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}."
195+
raise NotImplementedError(
196+
"JumpStart models only support scopes: "
197+
f"{', '.join(constants.SUPPORTED_JUMPSTART_SCOPES)}."
187198
)
188199

200+
assert model_id is not None
201+
assert version is not None
202+
189203
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(
190204
region=region, model_id=model_id, version=version
191205
)
192206

193-
if scope == constants.TRAINING and not model_specs.training_supported:
207+
if (
208+
scope == constants.JumpStartScriptScope.TRAINING.value
209+
and not model_specs.training_supported
210+
):
194211
raise ValueError(
195212
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
196213
)
@@ -199,27 +216,27 @@ def verify_model_region_and_return_specs(
199216
raise DeprecatedJumpStartModelError(model_id=model_id, version=version)
200217

201218
if (
202-
scope == constants.INFERENCE
219+
scope == constants.JumpStartScriptScope.INFERENCE.value
203220
and model_specs.inference_vulnerable
204221
and not tolerate_vulnerable_model
205222
):
206223
raise VulnerableJumpStartModelError(
207224
model_id=model_id,
208225
version=version,
209226
vulnerabilities=model_specs.inference_vulnerabilities,
210-
inference=True,
227+
scope=constants.JumpStartScriptScope.INFERENCE,
211228
)
212229

213230
if (
214-
scope == constants.TRAINING
231+
scope == constants.JumpStartScriptScope.TRAINING.value
215232
and model_specs.training_vulnerable
216233
and not tolerate_vulnerable_model
217234
):
218235
raise VulnerableJumpStartModelError(
219236
model_id=model_id,
220237
version=version,
221238
vulnerabilities=model_specs.training_vulnerabilities,
222-
inference=False,
239+
scope=constants.JumpStartScriptScope.TRAINING,
223240
)
224241

225242
return model_specs

src/sagemaker/model_uris.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def retrieve(
4242
model_scope (str): The model type, i.e. what it is used for.
4343
Valid values: "training" and "inference".
4444
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
45-
not thrown). False if these models should throw an exception. (Default: None).
45+
not raised). False if these models should raise an exception. (Default: None).
4646
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
47-
not thrown). False if these models should throw an exception. (Default: None).
47+
not raised). False if these models should raise an exception. (Default: None).
4848
Returns:
4949
str: the model artifact S3 URI for the corresponding model.
5050

src/sagemaker/script_uris.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,9 @@ def retrieve(
4242
script_scope (str): The script type, i.e. what it is used for.
4343
Valid values: "training" and "inference".
4444
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
45-
not thrown). False if these models should throw an exception. (Default: None).
45+
not raised). False if these models should raise an exception. (Default: None).
4646
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
47-
not thrown). False if these models should throw an exception. (Default: None).
47+
not raised). False if these models should raise an exception. (Default: None).
4848
Returns:
4949
str: the model script URI for the corresponding model.
5050

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_jumpstart_common_image_uri(
9696
)
9797
patched_verify_model_region_and_return_specs.assert_called_once()
9898

99-
with pytest.raises(ValueError):
99+
with pytest.raises(NotImplementedError):
100100
image_uris.retrieve(
101101
framework=None,
102102
region="us-west-2",

0 commit comments

Comments
 (0)