Skip to content

Commit 3c10d48

Browse files
committed
change: update name of jumpstart models accessor, fix small issues
1 parent 4e4590c commit 3c10d48

File tree

16 files changed

+68
-66
lines changed

16 files changed

+68
-66
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def get_sagemaker_version() -> str:
3434
return SageMakerSettings._parsed_sagemaker_version
3535

3636

37-
class JumpStartModelsCache(object):
37+
class JumpStartModelsAccessor(object):
3838
"""Static class for storing the JumpStart models cache."""
3939

4040
_cache: Optional[cache.JumpStartModelsCache] = None
@@ -67,15 +67,17 @@ def _validate_and_mutate_region_cache_kwargs(
6767

6868
@staticmethod
6969
def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
70-
"""Sets ``JumpStartModelsCache._cache`` and ``JumpStartModelsCache._curr_region``.
70+
"""Sets ``JumpStartModelsAccessor._cache`` and ``JumpStartModelsAccessor._curr_region``.
7171
7272
Args:
7373
region (str): region for which to retrieve header/spec.
7474
cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``.
7575
"""
76-
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
77-
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
78-
JumpStartModelsCache._curr_region = region
76+
if JumpStartModelsAccessor._cache is None or region != JumpStartModelsAccessor._curr_region:
77+
JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
78+
region=region, **cache_kwargs
79+
)
80+
JumpStartModelsAccessor._curr_region = region
7981

8082
@staticmethod
8183
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
@@ -86,12 +88,12 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
8688
model_id (str): model id to retrieve.
8789
version (str): semantic version to retrieve for the model id.
8890
"""
89-
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
90-
JumpStartModelsCache._cache_kwargs, region
91+
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
92+
JumpStartModelsAccessor._cache_kwargs, region
9193
)
92-
JumpStartModelsCache._set_cache_and_region(region, cache_kwargs)
93-
assert JumpStartModelsCache._cache is not None
94-
return JumpStartModelsCache._cache.get_header(model_id, version)
94+
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
95+
assert JumpStartModelsAccessor._cache is not None
96+
return JumpStartModelsAccessor._cache.get_header(model_id, version)
9597

9698
@staticmethod
9799
def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
@@ -102,12 +104,12 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
102104
model_id (str): model id to retrieve.
103105
version (str): semantic version to retrieve for the model id.
104106
"""
105-
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
106-
JumpStartModelsCache._cache_kwargs, region
107+
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
108+
JumpStartModelsAccessor._cache_kwargs, region
107109
)
108-
JumpStartModelsCache._set_cache_and_region(region, cache_kwargs)
109-
assert JumpStartModelsCache._cache is not None
110-
return JumpStartModelsCache._cache.get_specs(model_id, version)
110+
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
111+
assert JumpStartModelsAccessor._cache is not None
112+
return JumpStartModelsAccessor._cache.get_specs(model_id, version)
111113

112114
@staticmethod
113115
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
@@ -120,18 +122,18 @@ def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
120122
cache_kwargs (str): cache kwargs to validate.
121123
region (str): Optional. The region to validate along with the kwargs.
122124
"""
123-
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
125+
cache_kwargs = JumpStartModelsAccessor._validate_and_mutate_region_cache_kwargs(
124126
cache_kwargs, region
125127
)
126-
JumpStartModelsCache._cache_kwargs = cache_kwargs
128+
JumpStartModelsAccessor._cache_kwargs = cache_kwargs
127129
if region is None:
128-
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
129-
**JumpStartModelsCache._cache_kwargs
130+
JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
131+
**JumpStartModelsAccessor._cache_kwargs
130132
)
131133
else:
132-
JumpStartModelsCache._curr_region = region
133-
JumpStartModelsCache._cache = cache.JumpStartModelsCache(
134-
region=region, **JumpStartModelsCache._cache_kwargs
134+
JumpStartModelsAccessor._curr_region = region
135+
JumpStartModelsAccessor._cache = cache.JumpStartModelsCache(
136+
region=region, **JumpStartModelsAccessor._cache_kwargs
135137
)
136138

137139
@staticmethod
@@ -146,4 +148,4 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
146148
region (str): The region to validate along with the kwargs.
147149
"""
148150
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
149-
JumpStartModelsCache.set_cache_kwargs(cache_kwargs_dict, region)
151+
JumpStartModelsAccessor.set_cache_kwargs(cache_kwargs_dict, region)

src/sagemaker/jumpstart/artifacts.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _retrieve_image_uri(
4949
Args:
5050
model_id (str): JumpStart model ID for which to retrieve image URI.
5151
model_version (str): Version of the JumpStart model for which to retrieve
52-
the image URI (default: None).
52+
the image URI.
5353
image_scope (str): The image type, i.e. what it is used for.
5454
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
5555
``image_scope`` is ignored.
@@ -67,12 +67,10 @@ def _retrieve_image_uri(
6767
container_version (str): the version of docker image.
6868
Ideally the value of parameter should be created inside the framework.
6969
For custom use, see the list of supported container versions:
70-
https://github.com/aws/deep-learning-containers/blob/master/available_images.md
71-
(default: None).
70+
https://github.com/aws/deep-learning-containers/blob/master/available_images.md.
7271
distribution (dict): A dictionary with information on how to run distributed training
7372
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
74-
A configuration class for the SageMaker Training Compiler
75-
(default: None).
73+
A configuration class for the SageMaker Training Compiler.
7674
7775
Returns:
7876
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -94,7 +92,7 @@ def _retrieve_image_uri(
9492
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
9593
)
9694

97-
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
95+
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
9896
region, model_id, model_version
9997
)
10098

@@ -202,7 +200,7 @@ def _retrieve_model_uri(
202200
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
203201
)
204202

205-
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
203+
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
206204
region, model_id, model_version
207205
)
208206
if model_scope == INFERENCE:
@@ -261,7 +259,7 @@ def _retrieve_script_uri(
261259
f"JumpStart models only support scopes: {', '.join(SUPPORTED_JUMPSTART_SCOPES)}."
262260
)
263261

264-
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
262+
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
265263
region, model_id, model_version
266264
)
267265
if script_scope == INFERENCE:

tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
2323

2424

25-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
25+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2626
def test_jumpstart_catboost_image_uri(patched_get_model_specs):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
@@ -31,7 +31,7 @@ def test_jumpstart_catboost_image_uri(patched_get_model_specs):
3131
instance_type = "ml.p2.xlarge"
3232
region = "us-west-2"
3333

34-
model_specs = accessors.JumpStartModelsCache.get_model_specs(region, model_id, model_version)
34+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
3535

3636
# inference
3737
uri = image_uris.retrieve(

tests/unit/sagemaker/image_uris/jumpstart/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.jumpstart import constants as sagemaker_constants
2222

2323

24-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
24+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2525
def test_jumpstart_common_image_uri(patched_get_model_specs):
2626

2727
patched_get_model_specs.side_effect = get_spec_from_base_spec

tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
2323

2424

25-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
25+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2626
def test_jumpstart_huggingface_image_uri(patched_get_model_specs):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
@@ -31,7 +31,7 @@ def test_jumpstart_huggingface_image_uri(patched_get_model_specs):
3131
instance_type = "ml.p2.xlarge"
3232
region = "us-west-2"
3333

34-
model_specs = accessors.JumpStartModelsCache.get_model_specs(region, model_id, model_version)
34+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
3535

3636
# inference
3737
uri = image_uris.retrieve(

tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
2323

2424

25-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
25+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2626
def test_jumpstart_lightgbm_image_uri(patched_get_model_specs):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
@@ -31,7 +31,7 @@ def test_jumpstart_lightgbm_image_uri(patched_get_model_specs):
3131
instance_type = "ml.p2.xlarge"
3232
region = "us-west-2"
3333

34-
model_specs = accessors.JumpStartModelsCache.get_model_specs(region, model_id, model_version)
34+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
3535

3636
# inference
3737
uri = image_uris.retrieve(

tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
2323

2424

25-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
25+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2626
def test_jumpstart_mxnet_image_uri(patched_get_model_specs):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
@@ -31,7 +31,7 @@ def test_jumpstart_mxnet_image_uri(patched_get_model_specs):
3131
instance_type = "ml.p2.xlarge"
3232
region = "us-west-2"
3333

34-
model_specs = accessors.JumpStartModelsCache.get_model_specs(region, model_id, model_version)
34+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
3535

3636
# inference
3737
uri = image_uris.retrieve(

tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
2323

2424

25-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
25+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2626
def test_jumpstart_pytorch_image_uri(patched_get_model_specs):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
@@ -31,7 +31,7 @@ def test_jumpstart_pytorch_image_uri(patched_get_model_specs):
3131
instance_type = "ml.p2.xlarge"
3232
region = "us-west-2"
3333

34-
model_specs = accessors.JumpStartModelsCache.get_model_specs(region, model_id, model_version)
34+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
3535

3636
# inference
3737
uri = image_uris.retrieve(

tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
2424

2525

26-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
26+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2727
def test_jumpstart_sklearn_image_uri(patched_get_model_specs):
2828

2929
patched_get_model_specs.side_effect = get_prototype_model_spec
@@ -32,7 +32,7 @@ def test_jumpstart_sklearn_image_uri(patched_get_model_specs):
3232
instance_type = "ml.m2.xlarge"
3333
region = "us-west-2"
3434

35-
model_specs = accessors.JumpStartModelsCache.get_model_specs(region, model_id, model_version)
35+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
3636

3737
# inference
3838
uri = image_uris.retrieve(

tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
2323

2424

25-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
25+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2626
def test_jumpstart_tensorflow_image_uri(patched_get_model_specs):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
@@ -31,7 +31,7 @@ def test_jumpstart_tensorflow_image_uri(patched_get_model_specs):
3131
instance_type = "ml.p2.xlarge"
3232
region = "us-west-2"
3333

34-
model_specs = accessors.JumpStartModelsCache.get_model_specs(region, model_id, model_version)
34+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
3535

3636
# inference
3737
uri = image_uris.retrieve(

tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec
2323

2424

25-
@patch("sagemaker.jumpstart.accessors.JumpStartModelsCache.get_model_specs")
25+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
2626
def test_jumpstart_xgboost_image_uri(patched_get_model_specs):
2727

2828
patched_get_model_specs.side_effect = get_prototype_model_spec
@@ -31,7 +31,7 @@ def test_jumpstart_xgboost_image_uri(patched_get_model_specs):
3131
instance_type = "ml.p2.xlarge"
3232
region = "us-west-2"
3333

34-
model_specs = accessors.JumpStartModelsCache.get_model_specs(region, model_id, model_version)
34+
model_specs = accessors.JumpStartModelsAccessor.get_model_specs(region, model_id, model_version)
3535

3636
# inference
3737
uri = image_uris.retrieve(

0 commit comments

Comments
 (0)