Skip to content

Commit 5b98f42

Browse files
committed
change: cleanup code, remove redundant tests
1 parent f6ade25 commit 5b98f42

32 files changed

+165
-817
lines changed

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def read_version():
7979
"fabric>=2.0",
8080
"requests>=2.20.0, <3",
8181
"sagemaker-experiments",
82-
"regex",
8382
],
8483
)
8584

src/sagemaker/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,8 +2433,8 @@ def training_image_uri(self, region=None):
24332433
training.
24342434
24352435
Args:
2436-
region: Region to use for image uri.
2437-
Default: Region associated with SageMaker session.
2436+
region (str): Optional. AWS region to use for image URI. Default: AWS region associated
2437+
with the SageMaker session.
24382438
24392439
Returns:
24402440
str: The URI of the Docker image.

src/sagemaker/image_uris.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def retrieve(
4545
training_compiler_config=None,
4646
model_id=None,
4747
model_version=None,
48-
):
48+
) -> str:
4949
"""Retrieves the ECR URI for the Docker image matching the given arguments.
5050
5151
Ideally this function should not be called directly, rather it should be called from the
@@ -75,8 +75,10 @@ def retrieve(
7575
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
7676
A configuration class for the SageMaker Training Compiler
7777
(default: None).
78-
model_id (str): JumpStart model id for which to retrieve image URI.
79-
model_version (str): JumpStart model version for which to retrieve image URI.
78+
model_id (str): JumpStart model ID for which to retrieve image URI
79+
(default: None).
80+
model_version (str): Version of the JumpStart model for which to retrieve the
81+
image URI (default: None).
8082
8183
Returns:
8284
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -85,8 +87,11 @@ def retrieve(
8587
ValueError: If the combination of arguments specified is not supported.
8688
"""
8789
if is_jumpstart_model_input(model_id, model_version):
90+
91+
# adding assert statements to satisfy mypy type checker
8892
assert model_id is not None
8993
assert model_version is not None
94+
9095
return artifacts._retrieve_image_uri(
9196
model_id,
9297
model_version,

src/sagemaker/jumpstart/accessors.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,17 @@
2121
class SageMakerSettings(object):
2222
"""Static class for storing the SageMaker settings."""
2323

24-
_PARSED_SAGEMAKER_VERSION = ""
24+
_parsed_sagemaker_version = ""
2525

2626
@staticmethod
2727
def set_sagemaker_version(version: str) -> None:
2828
"""Set SageMaker version."""
29-
SageMakerSettings._PARSED_SAGEMAKER_VERSION = version
29+
SageMakerSettings._parsed_sagemaker_version = version
3030

3131
@staticmethod
3232
def get_sagemaker_version() -> str:
3333
"""Return SageMaker version."""
34-
return SageMakerSettings._PARSED_SAGEMAKER_VERSION
34+
return SageMakerSettings._parsed_sagemaker_version
3535

3636

3737
class JumpStartModelsCache(object):
@@ -43,7 +43,7 @@ class JumpStartModelsCache(object):
4343
_cache_kwargs: Dict[str, Any] = {}
4444

4545
@staticmethod
46-
def _validate_region_cache_kwargs(
46+
def _validate_and_mutate_region_cache_kwargs(
4747
cache_kwargs: Optional[Dict[str, Any]] = None, region: Optional[str] = None
4848
) -> Dict[str, Any]:
4949
"""Returns cache_kwargs with region argument removed if present.
@@ -74,7 +74,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
7474
model_id (str): model id to retrieve.
7575
version (str): semantic version to retrieve for the model id.
7676
"""
77-
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
77+
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
7878
JumpStartModelsCache._cache_kwargs, region
7979
)
8080
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
@@ -92,7 +92,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
9292
model_id (str): model id to retrieve.
9393
version (str): semantic version to retrieve for the model id.
9494
"""
95-
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(
95+
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
9696
JumpStartModelsCache._cache_kwargs, region
9797
)
9898
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
@@ -103,7 +103,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
103103

104104
@staticmethod
105105
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
106-
"""Sets cache kwargs, clear the cache.
106+
"""Sets cache kwargs, clears the cache.
107107
108108
Raises:
109109
ValueError: If region in `cache_kwargs` is inconsistent with `region` argument.
@@ -112,7 +112,9 @@ def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
112112
cache_kwargs (Dict[str, Any]): cache kwargs to validate.
113113
region (str): Optional. The region to validate along with the kwargs.
114114
"""
115-
cache_kwargs = JumpStartModelsCache._validate_region_cache_kwargs(cache_kwargs, region)
115+
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
116+
cache_kwargs, region
117+
)
116118
JumpStartModelsCache._cache_kwargs = cache_kwargs
117119
if region is None:
118120
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
@@ -125,7 +127,7 @@ def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
125127
)
126128

127129
@staticmethod
128-
def reset_cache(cache_kwargs: Dict[str, Any] = None, region: str = None) -> None:
130+
def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = None) -> None:
129131
"""Resets cache, optionally allowing cache kwargs to be passed to the new cache.
130132
131133
Raises:

src/sagemaker/jumpstart/artifacts.py

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""This module contains functions for obtainining JumpStart artifacts."""
13+
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414
from __future__ import absolute_import
1515
from typing import Optional
1616
from sagemaker import image_uris
@@ -42,13 +42,14 @@ def _retrieve_image_uri(
4242
):
4343
"""Retrieves the container image URI for JumpStart models.
4444
45-
Only `model_id` and `model_version` are required to be non-None;
45+
Only `model_id`, `model_version`, and `image_scope` are required;
4646
the rest of the fields are auto-populated.
4747
4848
4949
Args:
50-
model_id (str): JumpStart model id for which to retrieve image URI.
51-
model_version (str): JumpStart model version for which to retrieve image URI.
50+
model_id (str): JumpStart model ID for which to retrieve image URI.
51+
model_version (str): Version of the JumpStart model for which to retrieve
52+
the image URI (default: None).
5253
framework (str): The name of the framework or algorithm.
5354
region (str): The AWS region.
5455
version (str): The framework or algorithm version. This is required if there is
@@ -89,7 +90,9 @@ def _retrieve_image_uri(
8990
"Must specify `image_scope` argument to retrieve image uri for JumpStart models."
9091
)
9192
if image_scope not in SUPPORTED_JUMPSTART_SCOPES:
92-
raise ValueError("JumpStart models only support inference and training.")
93+
raise ValueError(
94+
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
95+
)
9396

9497
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
9598
region, model_id, model_version
@@ -99,25 +102,33 @@ def _retrieve_image_uri(
99102
ecr_specs = model_specs.hosting_ecr_specs
100103
elif image_scope == TRAINING:
101104
if not model_specs.training_supported:
102-
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
105+
raise ValueError(
106+
f"JumpStart model ID '{model_id}' and version '{model_version}' "
107+
"does not support training."
108+
)
103109
assert model_specs.training_ecr_specs is not None
104110
ecr_specs = model_specs.training_ecr_specs
105111

106112
if framework is not None and framework != ecr_specs.framework:
107-
raise ValueError(f"Bad value for container framework for JumpStart model: '{framework}'.")
113+
raise ValueError(
114+
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "
115+
f"and version '{model_version}'."
116+
)
108117

109118
if version is not None and version != ecr_specs.framework_version:
110119
raise ValueError(
111-
f"Bad value for container framework version for JumpStart model: '{version}'."
120+
f"Incorrect container framework version '{version}' for JumpStart model ID "
121+
f"'{model_id}' and version '{model_version}'."
112122
)
113123

114124
if py_version is not None and py_version != ecr_specs.py_version:
115125
raise ValueError(
116-
f"Bad value for container python version for JumpStart model: '{py_version}'."
126+
f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' "
127+
f"and version '{model_version}'."
117128
)
118129

119-
base_framework_version_override = None
120-
version_override = None
130+
base_framework_version_override: Optional[str] = None
131+
version_override: Optional[str] = None
121132
if ecr_specs.framework == ModelFramework.HUGGINGFACE.value:
122133
base_framework_version_override = ecr_specs.framework_version
123134
version_override = ecr_specs.huggingface_transformers_version
@@ -162,8 +173,10 @@ def _retrieve_model_uri(
162173
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
163174
164175
Args:
165-
model_id (str): JumpStart model id for which to retrieve model S3 URI.
166-
model_version (str): JumpStart model version for which to retrieve model S3 URI.
176+
model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
177+
the model artifact S3 URI.
178+
model_version (str): Version of the JumpStart model for which to retrieve the model
179+
artifact S3 URI.
167180
model_scope (str): The model type, i.e. what it is used for.
168181
Valid values: "training" and "inference".
169182
region (str): Region for which to retrieve model S3 URI.
@@ -185,7 +198,9 @@ def _retrieve_model_uri(
185198
)
186199

187200
if model_scope not in SUPPORTED_JUMPSTART_SCOPES:
188-
raise ValueError("JumpStart models only support inference and training.")
201+
raise ValueError(
202+
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
203+
)
189204

190205
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
191206
region, model_id, model_version
@@ -194,7 +209,10 @@ def _retrieve_model_uri(
194209
model_artifact_key = model_specs.hosting_artifact_key
195210
elif model_scope == TRAINING:
196211
if not model_specs.training_supported:
197-
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
212+
raise ValueError(
213+
f"JumpStart model ID '{model_id}' and version '{model_version}' "
214+
"does not support training."
215+
)
198216
assert model_specs.training_artifact_key is not None
199217
model_artifact_key = model_specs.training_artifact_key
200218

@@ -211,11 +229,13 @@ def _retrieve_script_uri(
211229
script_scope: Optional[str],
212230
region: Optional[str],
213231
):
214-
"""Retrieves the model script s3 URI for the model matching the given arguments.
232+
"""Retrieves the script S3 URI associated with the model matching the given arguments.
215233
216234
Args:
217-
model_id (str): JumpStart model id for which to retrieve model script S3 URI.
218-
model_version (str): JumpStart model version for which to retrieve model script S3 URI.
235+
model_id (str): JumpStart model ID of the JumpStart model for which to
236+
retrieve the script S3 URI.
237+
model_version (str): Version of the JumpStart model for which to
238+
retrieve the model script S3 URI.
219239
script_scope (str): The script type, i.e. what it is used for.
220240
Valid values: "training" and "inference".
221241
region (str): Region for which to retrieve model script S3 URI.
@@ -237,7 +257,9 @@ def _retrieve_script_uri(
237257
)
238258

239259
if script_scope not in SUPPORTED_JUMPSTART_SCOPES:
240-
raise ValueError("JumpStart models only support inference and training.")
260+
raise ValueError(
261+
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
262+
)
241263

242264
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
243265
region, model_id, model_version
@@ -246,7 +268,10 @@ def _retrieve_script_uri(
246268
model_script_key = model_specs.hosting_script_key
247269
elif script_scope == TRAINING:
248270
if not model_specs.training_supported:
249-
raise ValueError(f"JumpStart model id '{model_id}' does not support training.")
271+
raise ValueError(
272+
f"JumpStart model ID '{model_id}' and version '{model_version}' "
273+
"does not support training."
274+
)
250275
assert model_specs.training_script_key is not None
251276
model_script_key = model_specs.training_script_key
252277

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def _select_version(
297297
spec = SpecifierSet(f"=={semantic_version_str}")
298298
available_versions_filtered = list(spec.filter(available_versions))
299299
return (
300-
str(available_versions_filtered[0]) if available_versions_filtered != [] else None
300+
str(max(available_versions_filtered)) if available_versions_filtered != [] else None
301301
)
302302

303303
def _get_header_impl(

src/sagemaker/jumpstart/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ def is_jumpstart_model_input(model_id: Optional[str], version: Optional[str]) ->
122122
are None, and raises an exception if one argument is None but the other isn't.
123123
124124
Args:
125-
model_id (str): Optional. Model id of JumpStart model.
126-
version (str): Optional. Version for JumpStart model.
125+
model_id (str): Optional. Model ID of the JumpStart model.
126+
version (str): Optional. Version of the JumpStart model.
127127
128128
Raises:
129129
ValueError: If only one of the two arguments is None.

src/sagemaker/model_uris.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Functions for generating S3 model artifact URIs for pre-built SageMaker models."""
13+
"""Accessors to retrieve the model artifact S3 URI of pretrained ML models."""
1414
from __future__ import absolute_import
1515

1616
import logging
@@ -29,13 +29,15 @@ def retrieve(
2929
model_id=None,
3030
model_version: Optional[str] = None,
3131
model_scope: Optional[str] = None,
32-
):
32+
) -> str:
3333
"""Retrieves the model artifact S3 URI for the model matching the given arguments.
3434
3535
Args:
3636
region (str): Region for which to retrieve model S3 URI.
37-
model_id (str): JumpStart model id for which to retrieve model S3 URI.
38-
model_version (str): JumpStart model version for which to retrieve model S3 URI.
37+
model_id (str): JumpStart model ID of the JumpStart model for which to retrieve
38+
the model artifact S3 URI.
39+
model_version (str): Version of the JumpStart model for which to retrieve
40+
the model artifact S3 URI.
3941
model_scope (str): The model type, i.e. what it is used for.
4042
Valid values: "training" and "inference".
4143
Returns:
@@ -47,6 +49,8 @@ def retrieve(
4749
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
4850
raise ValueError("Must specify `model_id` and `model_version` when retrieving model URIs.")
4951

52+
# mypy type checking requires these assertions
5053
assert model_id is not None
5154
assert model_version is not None
55+
5256
return artifacts._retrieve_model_uri(model_id, model_version, model_scope, region)

src/sagemaker/script_uris.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Functions for generating S3 model script URIs for pre-built SageMaker models."""
13+
"""Accessors to retrieve the script S3 URI used to run pretrained ML models
14+
in SageMaker containers.
15+
"""
1416
from __future__ import absolute_import
1517

1618
import logging
@@ -27,13 +29,15 @@ def retrieve(
2729
model_id=None,
2830
model_version=None,
2931
script_scope=None,
30-
):
31-
"""Retrieves the model script s3 URI for the model matching the given arguments.
32+
) -> str:
33+
"""Retrieves the script S3 URI associated with the model matching the given arguments.
3234
3335
Args:
3436
region (str): Region for which to retrieve model script S3 URI.
35-
model_id (str): JumpStart model id for which to retrieve model script S3 URI.
36-
model_version (str): JumpStart model version for which to retrieve model script S3 URI.
37+
model_id (str): JumpStart model ID of the JumpStart model for which to
38+
retrieve the script S3 URI.
39+
model_version (str): Version of the JumpStart model for which to retrieve the
40+
model script S3 URI.
3741
script_scope (str): The script type, i.e. what it is used for.
3842
Valid values: "training" and "inference".
3943
Returns:
@@ -45,6 +49,8 @@ def retrieve(
4549
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
4650
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
4751

52+
# mypy type checking requires these assertions
4853
assert model_id is not None
4954
assert model_version is not None
55+
5056
return artifacts._retrieve_script_uri(model_id, model_version, script_scope, region)

0 commit comments

Comments
 (0)