@@ -65,6 +65,18 @@ def _validate_and_mutate_region_cache_kwargs(
65
65
del cache_kwargs_dict ["region" ]
66
66
return cache_kwargs_dict
67
67
68
+ @staticmethod
69
+ def _set_cache_and_region (region : str , cache_kwargs : dict ) -> None :
70
+ """Sets ``JumpStartModelsCache._cache`` and ``JumpStartModelsCache._curr_region``.
71
+
72
+ Args:
73
+ region (str): region for which to retrieve header/spec.
74
+ cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``.
75
+ """
76
+ if JumpStartModelsCache ._cache is None or region != JumpStartModelsCache ._curr_region :
77
+ JumpStartModelsCache ._cache = cache .JumpStartModelsCache (region = region , ** cache_kwargs )
78
+ JumpStartModelsCache ._curr_region = region
79
+
68
80
@staticmethod
69
81
def get_model_header (region : str , model_id : str , version : str ) -> JumpStartModelHeader :
70
82
"""Returns model header from JumpStart models cache.
@@ -77,9 +89,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
77
89
cache_kwargs = JumpStartModelsCache ._validate_and_mutate_region_cache_kwargs (
78
90
JumpStartModelsCache ._cache_kwargs , region
79
91
)
80
- if JumpStartModelsCache ._cache is None or region != JumpStartModelsCache ._curr_region :
81
- JumpStartModelsCache ._cache = cache .JumpStartModelsCache (region = region , ** cache_kwargs )
82
- JumpStartModelsCache ._curr_region = region
92
+ JumpStartModelsCache ._set_cache_and_region (region , cache_kwargs )
83
93
assert JumpStartModelsCache ._cache is not None
84
94
return JumpStartModelsCache ._cache .get_header (model_id , version )
85
95
@@ -95,9 +105,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
95
105
cache_kwargs = JumpStartModelsCache ._validate_and_mutate_region_cache_kwargs (
96
106
JumpStartModelsCache ._cache_kwargs , region
97
107
)
98
- if JumpStartModelsCache ._cache is None or region != JumpStartModelsCache ._curr_region :
99
- JumpStartModelsCache ._cache = cache .JumpStartModelsCache (region = region , ** cache_kwargs )
100
- JumpStartModelsCache ._curr_region = region
108
+ JumpStartModelsCache ._set_cache_and_region (region , cache_kwargs )
101
109
assert JumpStartModelsCache ._cache is not None
102
110
return JumpStartModelsCache ._cache .get_specs (model_id , version )
103
111
0 commit comments