Skip to content

Commit 09c97f4

Browse files
committed
change: cleanup code, docstrings
1 parent 4bf0299 commit 09c97f4

File tree

5 files changed

+67
-44
lines changed

5 files changed

+67
-44
lines changed

src/sagemaker/image_uris.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,13 @@ def retrieve(
8181
(default: None).
8282
model_version (str): Version of the JumpStart model for which to retrieve the
8383
image URI (default: None).
84-
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
85-
not raised). False if these models should raise an exception. (Default: None).
86-
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
87-
not raised). False if these models should raise an exception. (Default: None).
84+
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications
85+
should be tolerated (exception not raised). False or None, raises an exception if
86+
the script used by this version of the model has dependencies with known security
87+
vulnerabilities. (Default: None).
88+
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
89+
should be tolerated (exception not raised). False or None, raises an exception
90+
if the version of the model is deprecated. (Default: None).
8891
8992
Returns:
9093
str: the ECR URI for the corresponding SageMaker Docker image.

src/sagemaker/jumpstart/artifacts.py

Lines changed: 38 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -75,18 +75,22 @@ def _retrieve_image_uri(
7575
distribution (dict): A dictionary with information on how to run distributed training
7676
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7777
A configuration class for the SageMaker Training Compiler.
78-
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
79-
not raised). False if these models should raise an exception.
80-
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
81-
not raised). False if these models should raise an exception.
78+
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
79+
specifications should be tolerated (exception not raised). False or None, raises an
80+
exception if the script used by this version of the model has dependencies with known
81+
security vulnerabilities.
82+
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
83+
specifications should be tolerated (exception not raised). False or None, raises
84+
an exception if the version of the model is deprecated.
8285
8386
Returns:
8487
str: the ECR URI for the corresponding SageMaker Docker image.
8588
8689
Raises:
8790
ValueError: If the combination of arguments specified is not supported.
88-
VulnerableJumpStartModelError: If the model is vulnerable.
89-
DeprecatedJumpStartModelError: If the model is deprecated.
91+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
92+
known security vulnerabilities.
93+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
9094
"""
9195
if region is None:
9296
region = JUMPSTART_DEFAULT_REGION_NAME
@@ -102,9 +106,9 @@ def _retrieve_image_uri(
102106
tolerate_deprecated_model=tolerate_deprecated_model,
103107
)
104108

105-
if image_scope == JumpStartScriptScope.INFERENCE.value:
109+
if image_scope == JumpStartScriptScope.INFERENCE:
106110
ecr_specs = model_specs.hosting_ecr_specs
107-
elif image_scope == JumpStartScriptScope.TRAINING.value:
111+
elif image_scope == JumpStartScriptScope.TRAINING:
108112
assert model_specs.training_ecr_specs is not None
109113
ecr_specs = model_specs.training_ecr_specs
110114

@@ -128,11 +132,11 @@ def _retrieve_image_uri(
128132

129133
base_framework_version_override: Optional[str] = None
130134
version_override: Optional[str] = None
131-
if ecr_specs.framework == ModelFramework.HUGGINGFACE.value:
135+
if ecr_specs.framework == ModelFramework.HUGGINGFACE:
132136
base_framework_version_override = ecr_specs.framework_version
133137
version_override = ecr_specs.huggingface_transformers_version
134138

135-
if image_scope == JumpStartScriptScope.TRAINING.value:
139+
if image_scope == JumpStartScriptScope.TRAINING:
136140
return image_uris.get_training_image_uri(
137141
region=region,
138142
framework=ecr_specs.framework,
@@ -181,17 +185,21 @@ def _retrieve_model_uri(
181185
model_scope (str): The model type, i.e. what it is used for.
182186
Valid values: "training" and "inference".
183187
region (str): Region for which to retrieve model S3 URI.
184-
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
185-
not raised). False if these models should raise an exception.
186-
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
187-
not raised). False if these models should raise an exception.
188+
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
189+
specifications should be tolerated (exception not raised). False or None, raises an
190+
exception if the script used by this version of the model has dependencies with known
191+
security vulnerabilities.
192+
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
193+
specifications should be tolerated (exception not raised). False or None, raises
194+
an exception if the version of the model is deprecated.
188195
Returns:
189196
str: the model artifact S3 URI for the corresponding model.
190197
191198
Raises:
192199
ValueError: If the combination of arguments specified is not supported.
193-
VulnerableJumpStartModelError: If the model is vulnerable.
194-
DeprecatedJumpStartModelError: If the model is deprecated.
200+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
201+
known security vulnerabilities.
202+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
195203
"""
196204
if region is None:
197205
region = JUMPSTART_DEFAULT_REGION_NAME
@@ -207,9 +215,9 @@ def _retrieve_model_uri(
207215
tolerate_deprecated_model=tolerate_deprecated_model,
208216
)
209217

210-
if model_scope == JumpStartScriptScope.INFERENCE.value:
218+
if model_scope == JumpStartScriptScope.INFERENCE:
211219
model_artifact_key = model_specs.hosting_artifact_key
212-
elif model_scope == JumpStartScriptScope.TRAINING.value:
220+
elif model_scope == JumpStartScriptScope.TRAINING:
213221
assert model_specs.training_artifact_key is not None
214222
model_artifact_key = model_specs.training_artifact_key
215223

@@ -238,17 +246,21 @@ def _retrieve_script_uri(
238246
script_scope (str): The script type, i.e. what it is used for.
239247
Valid values: "training" and "inference".
240248
region (str): Region for which to retrieve model script S3 URI.
241-
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
242-
not raised). False if these models should raise an exception.
243-
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
244-
not raised). False if these models should raise an exception.
249+
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
250+
specifications should be tolerated (exception not raised). False or None, raises an
251+
exception if the script used by this version of the model has dependencies with known
252+
security vulnerabilities.
253+
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
254+
specifications should be tolerated (exception not raised). False or None, raises
255+
an exception if the version of the model is deprecated.
245256
Returns:
246257
str: the model script URI for the corresponding model.
247258
248259
Raises:
249260
ValueError: If the combination of arguments specified is not supported.
250-
VulnerableJumpStartModelError: If the model is vulnerable.
251-
DeprecatedJumpStartModelError: If the model is deprecated.
261+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
262+
known security vulnerabilities.
263+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
252264
"""
253265
if region is None:
254266
region = JUMPSTART_DEFAULT_REGION_NAME
@@ -264,9 +276,9 @@ def _retrieve_script_uri(
264276
tolerate_deprecated_model=tolerate_deprecated_model,
265277
)
266278

267-
if script_scope == JumpStartScriptScope.INFERENCE.value:
279+
if script_scope == JumpStartScriptScope.INFERENCE:
268280
model_script_key = model_specs.hosting_script_key
269-
elif script_scope == JumpStartScriptScope.TRAINING.value:
281+
elif script_scope == JumpStartScriptScope.TRAINING:
270282
assert model_specs.training_script_key is not None
271283
model_script_key = model_specs.training_script_key
272284

src/sagemaker/jumpstart/utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,19 +164,21 @@ def verify_model_region_and_return_specs(
164164
scope (Optional[str]): scope of the JumpStart model to verify.
165165
region (Optional[str]): region of the JumpStart model to verify and
166166
obtains specs.
167-
tolerate_vulnerable_model (Optional[bool]): True if vulnerable models should be tolerated
168-
(exception not raised). False if these models should raise an exception.
169-
(Default: None).
167+
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
168+
specifications should be tolerated (exception not raised). False or None, raises an
169+
exception if the script used by this version of the model has dependencies with known
170+
security vulnerabilities. (Default: None).
170171
tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated
171172
(exception not raised). False if these models should raise an exception.
172173
(Default: None).
173174
174175
175176
Raises:
176-
ValueError: If the combination of arguments specified is not supported.
177177
NotImplementedError: If the scope is not supported.
178-
VulnerableJumpStartModelError: If the model is vulnerable.
179-
DeprecatedJumpStartModelError: If the model is deprecated.
178+
ValueError: If the combination of arguments specified is not supported.
179+
VulnerableJumpStartModelError: If any of the dependencies required by the script have
180+
known security vulnerabilities.
181+
DeprecatedJumpStartModelError: If the version of the model is deprecated.
180182
"""
181183

182184
if tolerate_vulnerable_model is None:

src/sagemaker/model_uris.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,13 @@ def retrieve(
4141
the model artifact S3 URI.
4242
model_scope (str): The model type, i.e. what it is used for.
4343
Valid values: "training" and "inference".
44-
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
45-
not raised). False if these models should raise an exception. (Default: None).
46-
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
47-
not raised). False if these models should raise an exception. (Default: None).
44+
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
45+
specifications should be tolerated (exception not raised). False or None, raises an
46+
exception if the script used by this version of the model has dependencies with known
47+
security vulnerabilities. (Default: None).
48+
tolerate_deprecated_model (Optional[bool]): True if deprecated versions of model
49+
specifications should be tolerated (exception not raised). False or None, raises
50+
an exception if the version of the model is deprecated. (Default: None).
4851
Returns:
4952
str: the model artifact S3 URI for the corresponding model.
5053

src/sagemaker/script_uris.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,13 @@ def retrieve(
4141
model script S3 URI.
4242
script_scope (str): The script type, i.e. what it is used for.
4343
Valid values: "training" and "inference".
44-
tolerate_vulnerable_model (bool): True if vulnerable models should be tolerated (exception
45-
not raised). False if these models should raise an exception. (Default: None).
46-
tolerate_deprecated_model (bool): True if deprecated models should be tolerated (exception
47-
not raised). False if these models should raise an exception. (Default: None).
44+
tolerate_vulnerable_model (Optional[bool]): True if vulnerable versions of model
45+
specifications should be tolerated (exception not raised). False or None, raises an
46+
exception if the script used by this version of the model has dependencies with known
47+
security vulnerabilities. (Default: None).
48+
tolerate_deprecated_model (Optional[bool]): True if deprecated models should be tolerated
49+
(exception not raised). False if these models should raise an exception.
50+
(Default: None).
4851
Returns:
4952
str: the model script URI for the corresponding model.
5053

0 commit comments

Comments
 (0)