Skip to content

Commit 4e4590c

Browse files
committed
change: create helper function for setting kwargs and region for JumpStart cache
1 parent 64b51d6 commit 4e4590c

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

src/sagemaker/jumpstart/accessors.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,18 @@ def _validate_and_mutate_region_cache_kwargs(
6565
del cache_kwargs_dict["region"]
6666
return cache_kwargs_dict
6767

68+
@staticmethod
69+
def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
70+
"""Sets ``JumpStartModelsCache._cache`` and ``JumpStartModelsCache._curr_region``.
71+
72+
Args:
73+
region (str): region for which to retrieve header/spec.
74+
cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``.
75+
"""
76+
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
77+
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
78+
JumpStartModelsCache._curr_region = region
79+
6880
@staticmethod
6981
def get_model_header(region: str, model_id: str, version: str) -> JumpStartModelHeader:
7082
"""Returns model header from JumpStart models cache.
@@ -77,9 +89,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
7789
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
7890
JumpStartModelsCache._cache_kwargs, region
7991
)
80-
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
81-
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
82-
JumpStartModelsCache._curr_region = region
92+
JumpStartModelsCache._set_cache_and_region(region, cache_kwargs)
8393
assert JumpStartModelsCache._cache is not None
8494
return JumpStartModelsCache._cache.get_header(model_id, version)
8595

@@ -95,9 +105,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
95105
cache_kwargs = JumpStartModelsCache._validate_and_mutate_region_cache_kwargs(
96106
JumpStartModelsCache._cache_kwargs, region
97107
)
98-
if JumpStartModelsCache._cache is None or region != JumpStartModelsCache._curr_region:
99-
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
100-
JumpStartModelsCache._curr_region = region
108+
JumpStartModelsCache._set_cache_and_region(region, cache_kwargs)
101109
assert JumpStartModelsCache._cache is not None
102110
return JumpStartModelsCache._cache.get_specs(model_id, version)
103111

0 commit comments

Comments
 (0)