Skip to content

Commit 5eea777

Browse files
committed
feat: jumpstart hyperparameters and env variables
1 parent 64b51d6 commit 5eea777

File tree

25 files changed

+1604
-117
lines changed

25 files changed

+1604
-117
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve environment variables to run pretrained ML models."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from typing import Dict
19+
20+
from sagemaker.jumpstart import utils as jumpstart_utils
21+
from sagemaker.jumpstart import artifacts
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def retrieve_default(
    region=None,
    model_id=None,
    model_version=None,
) -> Dict[str, str]:
    """Retrieves the default environment variables for the model matching the given arguments.

    Args:
        region (str): Region for which to retrieve default environment variables.
            Passed through unchanged to the artifacts helper. (Default: None)
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default environment variables.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default environment variables.
    Returns:
        dict: the variables to use for the model.

    Raises:
        ValueError: If ``model_id`` and ``model_version`` are not accepted by
            ``jumpstart_utils.is_jumpstart_model_input``.
    """
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        # The message names environment variables (it previously said "script URIs",
        # copied over from the script URI accessor).
        raise ValueError(
            "Must specify `model_id` and `model_version` when retrieving environment variables."
        )

    # mypy type checking requires these assertions to narrow the Optional arguments.
    assert model_id is not None
    assert model_version is not None

    return artifacts._retrieve_default_environment_variables(model_id, model_version, region)

src/sagemaker/hyperparameters.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve hyperparameters to run pretrained ML models."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from typing import Dict
19+
20+
from sagemaker.jumpstart import utils as jumpstart_utils
21+
from sagemaker.jumpstart import artifacts
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def retrieve_default(
    region=None,
    model_id=None,
    model_version=None,
    include_container_hyperparameters=False,
) -> Dict[str, str]:
    """Retrieves the default training hyperparameters for the model matching the given arguments.

    Args:
        region (str): Region for which to retrieve default hyperparameters.
            Passed through unchanged to the artifacts helper. (Default: None)
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default hyperparameters.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default hyperparameters.
        include_container_hyperparameters (bool): True if container hyperparameters
            should be returned as well. (Default: False)
    Returns:
        dict: the hyperparameters to use for the model.

    Raises:
        ValueError: If ``model_id`` and ``model_version`` are not accepted by
            ``jumpstart_utils.is_jumpstart_model_input``.
    """
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        # The message names hyperparameters (it previously said "script URIs",
        # copied over from the script URI accessor).
        raise ValueError(
            "Must specify `model_id` and `model_version` when retrieving hyperparameters."
        )

    # mypy type checking requires these assertions to narrow the Optional arguments.
    assert model_id is not None
    assert model_version is not None

    return artifacts._retrieve_default_hyperparameters(
        model_id, model_version, region, include_container_hyperparameters
    )

src/sagemaker/jumpstart/accessors.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
8181
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
8282
JumpStartModelsCache._curr_region = region
8383
assert JumpStartModelsCache._cache is not None
84-
return JumpStartModelsCache._cache.get_header(model_id, version)
84+
return JumpStartModelsCache._cache.get_header(
85+
model_id=model_id, semantic_version_str=version
86+
)
8587

8688
@staticmethod
8789
def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelSpecs:
@@ -99,7 +101,9 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
99101
JumpStartModelsCache._cache = cache.JumpStartModelsCache(region=region, **cache_kwargs)
100102
JumpStartModelsCache._curr_region = region
101103
assert JumpStartModelsCache._cache is not None
102-
return JumpStartModelsCache._cache.get_specs(model_id, version)
104+
return JumpStartModelsCache._cache.get_specs(
105+
model_id=model_id, semantic_version_str=version
106+
)
103107

104108
@staticmethod
105109
def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:

src/sagemaker/jumpstart/artifacts.py

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _retrieve_image_uri(
9595
)
9696

9797
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
98-
region, model_id, model_version
98+
region=region, model_id=model_id, version=model_version
9999
)
100100

101101
if image_scope == INFERENCE:
@@ -203,7 +203,7 @@ def _retrieve_model_uri(
203203
)
204204

205205
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
206-
region, model_id, model_version
206+
region=region, model_id=model_id, version=model_version
207207
)
208208
if model_scope == INFERENCE:
209209
model_artifact_key = model_specs.hosting_artifact_key
@@ -262,7 +262,7 @@ def _retrieve_script_uri(
262262
)
263263

264264
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
265-
region, model_id, model_version
265+
region=region, model_id=model_id, version=model_version
266266
)
267267
if script_scope == INFERENCE:
268268
model_script_key = model_specs.hosting_script_key
@@ -280,3 +280,80 @@ def _retrieve_script_uri(
280280
script_s3_uri = f"s3://{bucket}/{model_script_key}"
281281

282282
return script_s3_uri
283+
284+
285+
def _retrieve_default_hyperparameters(
    model_id: str,
    model_version: str,
    region: Optional[str],
    include_container_hyperparameters: bool = False,
):
    """Retrieves the training hyperparameters for the model matching the given arguments.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default hyperparameters.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default hyperparameters.
        region (str): Region for which to retrieve default hyperparameters. When
            None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used.
        include_container_hyperparameters (bool): True if container hyperparameters
            should be returned as well. (Default: False)
    Returns:
        dict: maps each selected hyperparameter name to its default value
            converted with ``str``. Empty when the model specs carry no
            hyperparameters.

    Raises:
        ValueError: If the combination of arguments specified is not supported
            (not raised directly here; presumably propagated from the model
            specs lookup -- TODO confirm).
    """

    if region is None:
        region = JUMPSTART_DEFAULT_REGION_NAME

    # mypy needs this to narrow ``region`` from Optional[str] to str.
    assert region is not None

    model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
        region=region, model_id=model_id, version=model_version
    )

    # The specs object only gets a ``hyperparameters`` attribute when the model
    # supports training and its spec lists hyperparameters; treat a missing (or
    # None) attribute as "no hyperparameters" instead of raising AttributeError.
    hyperparameters = getattr(model_specs, "hyperparameters", None) or []

    # "algorithm"-scoped hyperparameters are always returned; "container"-scoped
    # ones only on request.
    if include_container_hyperparameters:
        accepted_scopes = ("algorithm", "container")
    else:
        accepted_scopes = ("algorithm",)

    default_hyperparameters = {}
    for hyperparameter in hyperparameters:
        if hyperparameter.scope in accepted_scopes:
            default_hyperparameters[hyperparameter.name] = str(hyperparameter.default)
    return default_hyperparameters
324+
325+
326+
def _retrieve_default_environment_variables(
    model_id: str,
    model_version: str,
    region: Optional[str],
):
    """Looks up the default inference environment variables of a JumpStart model.

    Args:
        model_id (str): JumpStart model ID of the JumpStart model for which to
            retrieve the default environment variables.
        model_version (str): Version of the JumpStart model for which to retrieve the
            default environment variables.
        region (str): Region for which to retrieve default environment variables.
            When None, ``JUMPSTART_DEFAULT_REGION_NAME`` is used.

    Returns:
        dict: maps each inference environment variable name to its default
            value converted with ``str``.

    Raises:
        ValueError: If the combination of arguments specified is not supported
            (not raised directly here; presumably propagated from the model
            specs lookup -- TODO confirm).
    """

    region = JUMPSTART_DEFAULT_REGION_NAME if region is None else region

    # mypy needs this to narrow ``region`` from Optional[str] to str.
    assert region is not None

    specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
        region=region, model_id=model_id, version=model_version
    )

    return {
        variable.name: str(variable.default)
        for variable in specs.inference_environment_variables
    }

src/sagemaker/jumpstart/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@
120120
TRAINING = "training"
121121
SUPPORTED_JUMPSTART_SCOPES = set([INFERENCE, TRAINING])
122122

123+
INFERENCE_ENTRYPOINT_SCRIPT_NAME = "inference.py"
124+
TRAINING_ENTRYPOINT_SCRIPT_NAME = "transfer_learning.py"
125+
123126

124127
class ModelFramework(str, Enum):
125128
"""Enum class for JumpStart model framework.

src/sagemaker/jumpstart/types.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,93 @@ def to_json(self) -> Dict[str, Any]:
170170
return json_obj
171171

172172

173+
class JumpStartHyperparameter(JumpStartDataHolderType):
    """Data class for JumpStart hyperparameter."""

    # A list (not a set) keeps attribute order deterministic when ``to_json``
    # iterates the slots, and matches how JumpStartModelSpecs declares its slots.
    __slots__ = [
        "name",
        "type",
        "options",
        "default",
        "scope",
        "min",
        "max",
    ]

    def __init__(self, spec: Dict[str, Any]):
        """Initializes a JumpStartHyperparameter object from its json representation.

        Args:
            spec (Dict[str, Any]): Dictionary representation of hyperparameter.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        ``name``, ``type``, ``default`` and ``scope`` are required keys (a
        missing one raises ``KeyError``). ``options``, ``min`` and ``max`` are
        optional and are only set on the object when present and not None.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of hyperparameter.
        """

        self.name = json_obj["name"]
        self.type = json_obj["type"]
        self.default = json_obj["default"]
        self.scope = json_obj["scope"]

        # Optional fields are left unset (rather than set to None) so that
        # ``to_json`` omits them via its ``hasattr`` check.
        options = json_obj.get("options")
        if options is not None:
            self.options = options

        min_val = json_obj.get("min")
        if min_val is not None:
            self.min = min_val

        max_val = json_obj.get("max")
        if max_val is not None:
            self.max = max_val

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartHyperparameter object.

        Only slots that have been assigned a value are included.
        """
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj
222+
223+
224+
class JumpStartEnvironmentVariable(JumpStartDataHolderType):
    """Data class for JumpStart environment variable."""

    # A list (not a set) keeps attribute order deterministic when ``to_json``
    # iterates the slots, and matches how JumpStartModelSpecs declares its slots.
    __slots__ = [
        "name",
        "type",
        "default",
        "scope",
    ]

    def __init__(self, spec: Dict[str, Any]):
        """Initializes a JumpStartEnvironmentVariable object from its json representation.

        Args:
            spec (Dict[str, Any]): Dictionary representation of environment variable.
        """
        self.from_json(spec)

    def from_json(self, json_obj: Dict[str, Any]) -> None:
        """Sets fields in object based on json.

        All four keys (``name``, ``type``, ``default``, ``scope``) are required;
        a missing one raises ``KeyError``.

        Args:
            json_obj (Dict[str, Any]): Dictionary representation of environment variable.
        """

        self.name = json_obj["name"]
        self.type = json_obj["type"]
        self.default = json_obj["default"]
        self.scope = json_obj["scope"]

    def to_json(self) -> Dict[str, Any]:
        """Returns json representation of JumpStartEnvironmentVariable object.

        Only slots that have been assigned a value are included.
        """
        json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
        return json_obj
258+
259+
173260
class JumpStartModelSpecs(JumpStartDataHolderType):
174261
"""Data class JumpStart model specs."""
175262

@@ -186,6 +273,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
186273
"training_artifact_key",
187274
"training_script_key",
188275
"hyperparameters",
276+
"inference_environment_variables",
189277
]
190278

191279
def __init__(self, spec: Dict[str, Any]):
@@ -210,22 +298,37 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
210298
self.hosting_artifact_key: str = json_obj["hosting_artifact_key"]
211299
self.hosting_script_key: str = json_obj["hosting_script_key"]
212300
self.training_supported: bool = bool(json_obj["training_supported"])
301+
self.inference_environment_variables = [
302+
JumpStartEnvironmentVariable(env_variable)
303+
for env_variable in json_obj["inference_environment_variables"]
304+
]
213305
if self.training_supported:
214306
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(
215307
json_obj["training_ecr_specs"]
216308
)
217309
self.training_artifact_key: str = json_obj["training_artifact_key"]
218310
self.training_script_key: str = json_obj["training_script_key"]
219-
self.hyperparameters: Dict[str, Any] = json_obj.get("hyperparameters", {})
311+
hyperparameters = json_obj.get("hyperparameters")
312+
if hyperparameters is not None:
313+
self.hyperparameters = [
314+
JumpStartHyperparameter(hyperparameter) for hyperparameter in hyperparameters
315+
]
220316

221317
def to_json(self) -> Dict[str, Any]:
    """Returns json representation of JumpStartModelSpecs object.

    Only slots that have been assigned a value are included. Values that are
    ``JumpStartDataHolderType`` instances are serialized through their own
    ``to_json``; lists are serialized element by element with the same rule;
    every other value is stored as-is.
    """
    json_obj = {}
    for att in self.__slots__:
        if hasattr(self, att):
            cur_val = getattr(self, att)
            if isinstance(cur_val, JumpStartDataHolderType):
                json_obj[att] = cur_val.to_json()
            elif isinstance(cur_val, list):
                # Lists may mix data-holder objects and plain values.
                json_obj[att] = [
                    obj.to_json() if isinstance(obj, JumpStartDataHolderType) else obj
                    for obj in cur_val
                ]
            else:
                json_obj[att] = cur_val
    return json_obj

0 commit comments

Comments
 (0)