Skip to content

Commit 2c09724

Browse files
committed
change: resolve comments
1 parent e14dd61 commit 2c09724

File tree

6 files changed

+45
-23
lines changed

6 files changed

+45
-23
lines changed

src/sagemaker/environment_variables.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Accessors to retrieve environment variables to run pretrained ML models."""
13+
"""Accessors to retrieve environment variables for hosting containers."""
1414

1515
from __future__ import absolute_import
1616

@@ -28,14 +28,15 @@ def retrieve_default(
2828
model_id=None,
2929
model_version=None,
3030
) -> Dict[str, str]:
31-
"""Retrieves the default environment variables for the model matching the given arguments.
31+
"""Retrieves the default container environment variables for the model matching the arguments.
3232
3333
Args:
34-
region (str): Region for which to retrieve default environment variables.
35-
model_id (str): JumpStart model ID of the JumpStart model for which to
36-
retrieve the default environment variables.
37-
model_version (str): Version of the JumpStart model for which to retrieve the
38-
default environment variables.
34+
region (str): Optional. Region for which to retrieve default environment variables.
35+
(Default: None).
36+
model_id (str): Optional. JumpStart model ID of the JumpStart model for which to
37+
retrieve the default environment variables. (Default: None).
38+
model_version (str): Optional. Version of the JumpStart model for which to retrieve the
39+
default environment variables. (Default: None).
3940
Returns:
4041
dict: the variables to use for the model.
4142

src/sagemaker/hyperparameters.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Accessors to retrieve hyperparameters to run pretrained ML models."""
13+
"""Accessors to retrieve hyperparameters for training jobs."""
1414

1515
from __future__ import absolute_import
1616

@@ -38,7 +38,12 @@ def retrieve_default(
3838
model_version (str): Version of the JumpStart model for which to retrieve the
3939
default hyperparameters.
4040
include_container_hyperparameters (bool): True if container hyperparameters
41-
should be returned as well. (Default: False)
41+
should be returned as well. Container hyperparameters are not used to tune
42+
the specific algorithm, but rather by SageMaker Training to set up
43+
the training container environment. For example, there is a container hyperparameter
44+
that indicates the entrypoint script to use. These hyperparameters may be required
45+
when creating a training job with boto3; however, the ``Estimator`` classes
46+
should take care of adding container hyperparameters to the job. (Default: False).
4247
Returns:
4348
dict: the hyperparameters to use for the model.
4449
@@ -48,10 +53,6 @@ def retrieve_default(
4853
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
4954
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")
5055

51-
# mypy type checking requires these assertions
52-
assert model_id is not None
53-
assert model_version is not None
54-
5556
return artifacts._retrieve_default_hyperparameters(
5657
model_id, model_version, region, include_container_hyperparameters
5758
)

src/sagemaker/jumpstart/artifacts.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414
from __future__ import absolute_import
15-
from typing import Optional
15+
from typing import Dict, Optional
1616
from sagemaker import image_uris
1717
from sagemaker.jumpstart.constants import (
1818
JUMPSTART_DEFAULT_REGION_NAME,
1919
INFERENCE,
2020
TRAINING,
2121
SUPPORTED_JUMPSTART_SCOPES,
2222
ModelFramework,
23+
VariableScope,
2324
)
2425
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
2526
from sagemaker.jumpstart import accessors as jumpstart_accessors
@@ -295,7 +296,12 @@ def _retrieve_default_hyperparameters(
295296
default hyperparameters.
296297
region (str): Region for which to retrieve default hyperparameters.
297298
include_container_hyperparameters (bool): True if container hyperparameters
298-
should be returned as well. (Default: False)
299+
should be returned as well. Container hyperparameters are not used to tune
300+
the specific algorithm, but rather by SageMaker Training to set up
301+
the training container environment. For example, there is a container hyperparameter
302+
that indicates the entrypoint script to use. These hyperparameters may be required
303+
when creating a training job with boto3; however, the ``Estimator`` classes
304+
should take care of adding container hyperparameters to the job. (Default: False).
299305
Returns:
300306
dict: the hyperparameters to use for the model.
301307
@@ -312,11 +318,11 @@ def _retrieve_default_hyperparameters(
312318
region=region, model_id=model_id, version=model_version
313319
)
314320

315-
default_hyperparameters = {}
321+
default_hyperparameters: Dict[str, str] = {}
316322
for hyperparameter in model_specs.hyperparameters:
317323
if (
318-
include_container_hyperparameters and hyperparameter.scope == "container"
319-
) or hyperparameter.scope == "algorithm":
324+
include_container_hyperparameters and hyperparameter.scope == VariableScope.CONTAINER
325+
) or hyperparameter.scope == VariableScope.ALGORITHM:
320326
default_hyperparameters[hyperparameter.name] = str(hyperparameter.default)
321327
return default_hyperparameters
322328

@@ -333,7 +339,7 @@ def _retrieve_default_environment_variables(
333339
retrieve the default environment variables.
334340
model_version (str): Version of the JumpStart model for which to retrieve the
335341
default environment variables.
336-
region (str): Region for which to retrieve default environment variables.
342+
region (Optional[str]): Region for which to retrieve default environment variables.
337343
338344
Returns:
339345
dict: the inference environment variables to use for the model.
@@ -345,13 +351,11 @@ def _retrieve_default_environment_variables(
345351
if region is None:
346352
region = JUMPSTART_DEFAULT_REGION_NAME
347353

348-
assert region is not None
349-
350354
model_specs = jumpstart_accessors.JumpStartModelsCache.get_model_specs(
351355
region=region, model_id=model_id, version=model_version
352356
)
353357

354-
default_environment_variables = {}
358+
default_environment_variables: Dict[str, str] = {}
355359
for environment_variable in model_specs.inference_environment_variables:
356360
default_environment_variables[environment_variable.name] = str(environment_variable.default)
357361
return default_environment_variables

src/sagemaker/jumpstart/constants.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,13 @@ class ModelFramework(str, Enum):
139139
CATBOOST = "catboost"
140140
XGBOOST = "xgboost"
141141
SKLEARN = "sklearn"
142+
143+
144+
class VariableScope(str, Enum):
145+
"""Enum class for variable scope.
146+
147+
Used for hosting environment variables and training hyperparameters.
148+
"""
149+
150+
CONTAINER = "container"
151+
ALGORITHM = "algorithm"

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
308308
)
309309
self.training_artifact_key: str = json_obj["training_artifact_key"]
310310
self.training_script_key: str = json_obj["training_script_key"]
311-
hyperparameters = json_obj.get("hyperparameters")
311+
hyperparameters: Any = json_obj.get("hyperparameters")
312312
if hyperparameters is not None:
313313
self.hyperparameters = [
314314
JumpStartHyperparameter(hyperparameter) for hyperparameter in hyperparameters

tests/unit/sagemaker/jumpstart/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ def get_header_from_base_header(
3939
version: str = None,
4040
) -> JumpStartModelHeader:
4141

42+
if version and semantic_version_str:
43+
raise ValueError()
44+
4245
if "pytorch" not in model_id and "tensorflow" not in model_id:
4346
raise KeyError("Bad model id")
4447

@@ -72,6 +75,9 @@ def get_spec_from_base_spec(
7275
version: str = None,
7376
) -> JumpStartModelSpecs:
7477

78+
if version and semantic_version_str:
79+
raise ValueError()
80+
7581
if "pytorch" not in model_id and "tensorflow" not in model_id:
7682
raise KeyError("Bad model id")
7783

0 commit comments

Comments
 (0)