Skip to content

Commit 00f23e6

Browse files
authored
feature: jumpstart hyperparameters and environment variables (#2850)
1 parent b691d3d commit 00f23e6

34 files changed

+1685
-130
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve environment variables for hosting containers."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from typing import Dict
19+
20+
from sagemaker.jumpstart import utils as jumpstart_utils
21+
from sagemaker.jumpstart import artifacts
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def retrieve_default(
    region=None,
    model_id=None,
    model_version=None,
) -> Dict[str, str]:
    """Retrieves the default container environment variables for the model matching the arguments.

    This is a thin public wrapper: it validates the model ID/version pair and then
    delegates to ``artifacts._retrieve_default_environment_variables``.

    Args:
        region (str): Optional. Region for which to retrieve default environment variables.
            Passed through unchanged to the ``artifacts`` helper. (Default: None).
        model_id (str): Optional. Model ID of the model for which to
            retrieve the default environment variables. (Default: None).
        model_version (str): Optional. Version of the model for which to retrieve the
            default environment variables. (Default: None).
    Returns:
        dict: the variables to use for the model.

    Raises:
        ValueError: If ``jumpstart_utils.is_jumpstart_model_input`` rejects the
            ``model_id``/``model_version`` combination.
    """
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        # The message names the resource this accessor actually retrieves
        # (it previously mentioned "script URIs", copied from another accessor).
        raise ValueError(
            "Must specify `model_id` and `model_version` when retrieving environment variables."
        )

    # mypy type checking requires these assertions
    assert model_id is not None
    assert model_version is not None

    return artifacts._retrieve_default_environment_variables(model_id, model_version, region)

src/sagemaker/hyperparameters.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve hyperparameters for training jobs."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from typing import Dict
19+
20+
from sagemaker.jumpstart import utils as jumpstart_utils
21+
from sagemaker.jumpstart import artifacts
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def retrieve_default(
    region=None,
    model_id=None,
    model_version=None,
    include_container_hyperparameters=False,
) -> Dict[str, str]:
    """Retrieves the default training hyperparameters for the model matching the given arguments.

    This is a thin public wrapper: it validates the model ID/version pair and then
    delegates to ``artifacts._retrieve_default_hyperparameters``.

    Args:
        region (str): Region for which to retrieve default hyperparameters.
            Passed through unchanged to the ``artifacts`` helper. (Default: None).
        model_id (str): Model ID of the model for which to
            retrieve the default hyperparameters. (Default: None).
        model_version (str): Version of the model for which to retrieve the
            default hyperparameters. (Default: None).
        include_container_hyperparameters (bool): True if container hyperparameters
            should be returned as well. Container hyperparameters are not used to tune
            the specific algorithm, but rather by SageMaker Training to setup
            the training container environment. For example, there is a container hyperparameter
            that indicates the entrypoint script to use. These hyperparameters may be required
            when creating a training job with boto3, however the ``Estimator`` classes
            should take care of adding container hyperparameters to the job. (Default: False).
    Returns:
        dict: the hyperparameters to use for the model.

    Raises:
        ValueError: If ``jumpstart_utils.is_jumpstart_model_input`` rejects the
            ``model_id``/``model_version`` combination.
    """
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        # The message names the resource this accessor actually retrieves
        # (it previously mentioned "script URIs", copied from another accessor).
        raise ValueError(
            "Must specify `model_id` and `model_version` when retrieving hyperparameters."
        )

    return artifacts._retrieve_default_hyperparameters(
        model_id, model_version, region, include_container_hyperparameters
    )

src/sagemaker/jumpstart/accessors.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,9 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
9393
)
9494
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
9595
assert JumpStartModelsAccessor._cache is not None
96-
return JumpStartModelsAccessor._cache.get_header(model_id, version)
96+
return JumpStartModelsAccessor._cache.get_header(
97+
model_id=model_id, semantic_version_str=version
98+
)
9799

98100
@staticmethod
99101
def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
@@ -109,7 +111,9 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
109111
)
110112
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
111113
assert JumpStartModelsAccessor._cache is not None
112-
return JumpStartModelsAccessor._cache.get_specs(model_id, version)
114+
return JumpStartModelsAccessor._cache.get_specs(
115+
model_id=model_id, semantic_version_str=version
116+
)
113117

114118
@staticmethod
115119
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:

src/sagemaker/jumpstart/artifacts.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414
from __future__ import absolute_import
15-
from typing import Optional
15+
from typing import Dict, Optional
1616
from sagemaker import image_uris
1717
from sagemaker.jumpstart.constants import (
1818
JUMPSTART_DEFAULT_REGION_NAME,
1919
INFERENCE,
2020
TRAINING,
2121
SUPPORTED_JUMPSTART_SCOPES,
2222
ModelFramework,
23+
VariableScope,
2324
)
2425
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2526
from sagemaker.jumpstart import accessors as jumpstart_accessors
@@ -93,7 +94,7 @@ def _retrieve_image_uri(
9394
)
9495

9596
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
96-
region, model_id, model_version
97+
region=region, model_id=model_id, version=model_version
9798
)
9899

99100
if image_scope == INFERENCE:
@@ -110,19 +111,19 @@ def _retrieve_image_uri(
110111
if framework is not None and framework != ecr_specs.framework:
111112
raise ValueError(
112113
f"Incorrect container framework '{framework}' for JumpStart model ID '{model_id}' "
113-
f"and version {model_version}'."
114+
f"and version '{model_version}'."
114115
)
115116

116117
if version is not None and version != ecr_specs.framework_version:
117118
raise ValueError(
118119
f"Incorrect container framework version '{version}' for JumpStart model ID "
119-
f"'{model_id}' and version {model_version}'."
120+
f"'{model_id}' and version '{model_version}'."
120121
)
121122

122123
if py_version is not None and py_version != ecr_specs.py_version:
123124
raise ValueError(
124125
f"Incorrect python version '{py_version}' for JumpStart model ID '{model_id}' "
125-
f"and version {model_version}'."
126+
f"and version '{model_version}'."
126127
)
127128

128129
base_framework_version_override: Optional[str] = None
@@ -201,7 +202,7 @@ def _retrieve_model_uri(
201202
)
202203

203204
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
204-
region, model_id, model_version
205+
region=region, model_id=model_id, version=model_version
205206
)
206207
if model_scope == INFERENCE:
207208
model_artifact_key = model_specs.hosting_artifact_key
@@ -260,7 +261,7 @@ def _retrieve_script_uri(
260261
)
261262

262263
model_specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
263-
region, model_id, model_version
264+
region=region, model_id=model_id, version=model_version
264265
)
265266
if script_scope == INFERENCE:
266267
model_script_key = model_specs.hosting_script_key
@@ -278,3 +279,77 @@ def _retrieve_script_uri(
278279
script_s3_uri = f"s3://{bucket}/{model_script_key}"
279280

280281
return script_s3_uri
282+
283+
284+
def _retrieve_default_hyperparameters(
    model_id: str,
    model_version: str,
    region: Optional[str],
    include_container_hyperparameters: bool = False,
):
    """Builds the default training hyperparameters for a JumpStart model.

    The model specs are fetched through ``JumpStartModelsAccessor.get_model_specs`` and
    each selected hyperparameter is mapped from its name to its stringified default.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default hyperparameters.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default hyperparameters.
        region (str): Region for which to retrieve default hyperparameters. When None,
            ``JUMPSTART_DEFAULT_REGION_NAME`` is used.
        include_container_hyperparameters (bool): True if container hyperparameters
            should be returned as well. Container hyperparameters are not used to tune
            the specific algorithm, but rather by SageMaker Training to setup
            the training container environment. For example, there is a container hyperparameter
            that indicates the entrypoint script to use. These hyperparameters may be required
            when creating a training job with boto3, however the ``Estimator`` classes
            should take care of adding container hyperparameters to the job. (Default: False).
    Returns:
        dict: the hyperparameters to use for the model.
    """

    if region is None:
        region = JUMPSTART_DEFAULT_REGION_NAME

    assert region is not None

    specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
        region=region, model_id=model_id, version=model_version
    )

    # Algorithm-scoped entries are always kept; container-scoped entries only on request.
    # Scopes are compared with ``==`` so that plain-string scopes still match the enum.
    return {
        param.name: str(param.default)
        for param in specs.hyperparameters
        if (include_container_hyperparameters and param.scope == VariableScope.CONTAINER)
        or param.scope == VariableScope.ALGORITHM
    }
325+
326+
327+
def _retrieve_default_environment_variables(
    model_id: str,
    model_version: str,
    region: Optional[str],
):
    """Builds the default inference environment variables for a JumpStart model.

    The model specs are fetched through ``JumpStartModelsAccessor.get_model_specs`` and
    every entry of ``inference_environment_variables`` is mapped from its name to its
    stringified default.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default environment variables.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default environment variables.
        region (Optional[str]): Region for which to retrieve default environment variables.
            When None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used.

    Returns:
        dict: the inference environment variables to use for the model.
    """

    if region is None:
        region = JUMPSTART_DEFAULT_REGION_NAME

    specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
        region=region, model_id=model_id, version=model_version
    )

    return {
        env_var.name: str(env_var.default) for env_var in specs.inference_environment_variables
    }

src/sagemaker/jumpstart/constants.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@
120120
TRAINING = "training"
121121
SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING])
122122

123+
INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py"
124+
TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py"
125+
123126

124127
class ModelFramework(str, Enum):
125128
"""Enum class for JumpStart model framework.
@@ -136,3 +139,13 @@ class ModelFramework(str, Enum):
136139
CATBOOST = "catboost"
137140
XGBOOST = "xgboost"
138141
SKLEARN = "sklearn"
142+
143+
144+
class VariableScope(str, Enum):
    """Allowed values of the ``scope`` attribute on a hyperparameter or environment variable.

    Applies to both hosting environment variables and training hyperparameters.
    Members are also ``str`` instances, so they compare equal to their raw string value.
    """

    # Scope value "container".
    CONTAINER = "container"
    # Scope value "algorithm".
    ALGORITHM = "algorithm"

0 commit comments

Comments
 (0)