12
12
# language governing permissions and limitations under the License.
13
13
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
14
14
from __future__ import absolute_import
15
- from typing import Optional
15
+ from typing import Dict , Optional
16
16
from sagemaker import image_uris
17
17
from sagemaker .jumpstart .constants import (
18
18
JUMPSTART_DEFAULT_REGION_NAME ,
19
19
INFERENCE ,
20
20
TRAINING ,
21
21
SUPPORTED_JUMPSTART_SCOPES ,
22
22
ModelFramework ,
23
+ VariableScope ,
23
24
)
24
25
from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
25
26
from sagemaker .jumpstart import accessors as jumpstart_accessors
@@ -93,7 +94,7 @@ def _retrieve_image_uri(
93
94
)
94
95
95
96
model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
96
- region , model_id , model_version
97
+ region = region , model_id = model_id , version = model_version
97
98
)
98
99
99
100
if image_scope == INFERENCE :
@@ -110,19 +111,19 @@ def _retrieve_image_uri(
110
111
if framework is not None and framework != ecr_specs .framework :
111
112
raise ValueError (
112
113
f"Incorrect container framework '{ framework } ' for JumpStart model ID '{ model_id } ' "
113
- f"and version { model_version } '."
114
+ f"and version ' { model_version } '."
114
115
)
115
116
116
117
if version is not None and version != ecr_specs .framework_version :
117
118
raise ValueError (
118
119
f"Incorrect container framework version '{ version } ' for JumpStart model ID "
119
- f"'{ model_id } ' and version { model_version } '."
120
+ f"'{ model_id } ' and version ' { model_version } '."
120
121
)
121
122
122
123
if py_version is not None and py_version != ecr_specs .py_version :
123
124
raise ValueError (
124
125
f"Incorrect python version '{ py_version } ' for JumpStart model ID '{ model_id } ' "
125
- f"and version { model_version } '."
126
+ f"and version ' { model_version } '."
126
127
)
127
128
128
129
base_framework_version_override : Optional [str ] = None
@@ -201,7 +202,7 @@ def _retrieve_model_uri(
201
202
)
202
203
203
204
model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
204
- region , model_id , model_version
205
+ region = region , model_id = model_id , version = model_version
205
206
)
206
207
if model_scope == INFERENCE :
207
208
model_artifact_key = model_specs .hosting_artifact_key
@@ -260,7 +261,7 @@ def _retrieve_script_uri(
260
261
)
261
262
262
263
model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
263
- region , model_id , model_version
264
+ region = region , model_id = model_id , version = model_version
264
265
)
265
266
if script_scope == INFERENCE :
266
267
model_script_key = model_specs .hosting_script_key
@@ -278,3 +279,77 @@ def _retrieve_script_uri(
278
279
script_s3_uri = f"s3://{ bucket } /{ model_script_key } "
279
280
280
281
return script_s3_uri
282
+
283
+
284
+ def _retrieve_default_hyperparameters (
285
+ model_id : str ,
286
+ model_version : str ,
287
+ region : Optional [str ],
288
+ include_container_hyperparameters : bool = False ,
289
+ ):
290
+ """Retrieves the training hyperparameters for the model matching the given arguments.
291
+
292
+ Args:
293
+ model_id (str): JumpStart model ID of the JumpStart model for which to
294
+ retrieve the default hyperparameters.
295
+ model_version (str): Version of the JumpStart model for which to retrieve the
296
+ default hyperparameters.
297
+ region (str): Region for which to retrieve default hyperparameters.
298
+ include_container_hyperparameters (bool): True if container hyperparameters
299
+ should be returned as well. Container hyperparameters are not used to tune
300
+ the specific algorithm, but rather by SageMaker Training to setup
301
+ the training container environment. For example, there is a container hyperparameter
302
+ that indicates the entrypoint script to use. These hyperparameters may be required
303
+ when creating a training job with boto3, however the ``Estimator`` classes
304
+ should take care of adding container hyperparameters to the job. (Default: False).
305
+ Returns:
306
+ dict: the hyperparameters to use for the model.
307
+ """
308
+
309
+ if region is None :
310
+ region = JUMPSTART_DEFAULT_REGION_NAME
311
+
312
+ assert region is not None
313
+
314
+ model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
315
+ region = region , model_id = model_id , version = model_version
316
+ )
317
+
318
+ default_hyperparameters : Dict [str , str ] = {}
319
+ for hyperparameter in model_specs .hyperparameters :
320
+ if (
321
+ include_container_hyperparameters and hyperparameter .scope == VariableScope .CONTAINER
322
+ ) or hyperparameter .scope == VariableScope .ALGORITHM :
323
+ default_hyperparameters [hyperparameter .name ] = str (hyperparameter .default )
324
+ return default_hyperparameters
325
+
326
+
327
+ def _retrieve_default_environment_variables (
328
+ model_id : str ,
329
+ model_version : str ,
330
+ region : Optional [str ],
331
+ ):
332
+ """Retrieves the inference environment variables for the model matching the given arguments.
333
+
334
+ Args:
335
+ model_id (str): JumpStart model ID of the JumpStart model for which to
336
+ retrieve the default environment variables.
337
+ model_version (str): Version of the JumpStart model for which to retrieve the
338
+ default environment variables.
339
+ region (Optional[str]): Region for which to retrieve default environment variables.
340
+
341
+ Returns:
342
+ dict: the inference environment variables to use for the model.
343
+ """
344
+
345
+ if region is None :
346
+ region = JUMPSTART_DEFAULT_REGION_NAME
347
+
348
+ model_specs = jumpstart_accessors .JumpStartModelsAccessor .get_model_specs (
349
+ region = region , model_id = model_id , version = model_version
350
+ )
351
+
352
+ default_environment_variables : Dict [str , str ] = {}
353
+ for environment_variable in model_specs .inference_environment_variables :
354
+ default_environment_variables [environment_variable .name ] = str (environment_variable .default )
355
+ return default_environment_variables
0 commit comments