Skip to content

feat: tagging jumpstart models #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: feat/script_mode_for_model_class
Choose a base branch
from
10 changes: 5 additions & 5 deletions src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(
self.validate_train_spec()
self.hyperparameter_definitions = self._parse_hyperparameters()

self.hyperparam_dict = {}
self._hyperparameters = {}
if hyperparameters:
self.set_hyperparameters(**hyperparameters)

Expand Down Expand Up @@ -215,7 +215,7 @@ def set_hyperparameters(self, **kwargs):
"""Placeholder docstring"""
for k, v in kwargs.items():
value = self._validate_and_cast_hyperparameter(k, v)
self.hyperparam_dict[k] = value
self._hyperparameters[k] = value

self._validate_and_set_default_hyperparameters()

Expand All @@ -225,7 +225,7 @@ def hyperparameters(self):
The fit() method, that does the model training, calls this method to
find the hyperparameters you specified.
"""
return self.hyperparam_dict
return self._hyperparameters

def training_image_uri(self):
"""Returns the docker image to use for training.
Expand Down Expand Up @@ -464,10 +464,10 @@ def _validate_and_set_default_hyperparameters(self):
# Check if all the required hyperparameters are set. If there is a default value
# for one, set it.
for name, definition in self.hyperparameter_definitions.items():
if name not in self.hyperparam_dict:
if name not in self._hyperparameters:
spec = definition["spec"]
if "DefaultValue" in spec:
self.hyperparam_dict[name] = spec["DefaultValue"]
self._hyperparameters[name] = spec["DefaultValue"]
elif "IsRequired" in spec and spec["IsRequired"]:
raise ValueError("Required hyperparameter: %s is not set" % name)

Expand Down
6 changes: 4 additions & 2 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import logging

from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
framework_name_from_image,
framework_version_from_tag,
Expand Down Expand Up @@ -158,7 +158,9 @@ def hyperparameters(self):

# remove unset keys.
additional_hyperparameters = {k: v for k, v in additional_hyperparameters.items() if v}
hyperparameters.update(Framework._json_encode_hyperparameters(additional_hyperparameters))
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)
return hyperparameters

def create_model(
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
4 changes: 0 additions & 4 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,4 @@ def retrieve_default(
if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
raise ValueError("Must specify `model_id` and `model_version` when retrieving script URIs.")

# mypy type checking requires these assertions
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_default_environment_variables(model_id, model_version, region)
524 changes: 420 additions & 104 deletions src/sagemaker/estimator.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions src/sagemaker/huggingface/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re

from sagemaker.deprecations import renamed_kwargs
from sagemaker.estimator import Framework
from sagemaker.estimator import Framework, EstimatorBase
from sagemaker.fw_utils import (
framework_name_from_image,
warn_if_parameter_server_with_multi_gpu,
Expand Down Expand Up @@ -246,13 +246,13 @@ def hyperparameters(self):
distribution=self.distribution
)
hyperparameters.update(
Framework._json_encode_hyperparameters(distributed_training_hyperparameters)
EstimatorBase._json_encode_hyperparameters(distributed_training_hyperparameters)
)

if self.compiler_config:
training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
hyperparameters.update(
Framework._json_encode_hyperparameters(training_compiler_hyperparameters)
EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
)

return hyperparameters
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def prepare_container_def(self, instance_type=None, accelerator_type=None):
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
self._upload_code(deploy_key_prefix, repack=True)
deploy_env = dict(self.env)
deploy_env.update(self._framework_env_vars())
deploy_env.update(self._script_mode_env_vars())

if self.model_server_workers:
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
Expand Down
19 changes: 15 additions & 4 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
tolerate_vulnerable_model=False,
tolerate_deprecated_model=False,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Expand Down Expand Up @@ -79,19 +81,26 @@ def retrieve(
(default: None).
model_version (str): Version of the JumpStart model for which to retrieve the
image URI (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model specifications
should be tolerated (exception not raised). If False, raises an exception if
the script used by this version of the model has dependencies with known security
vulnerabilities. (Default: False).
tolerate_deprecated_model (bool): True if deprecated versions of model specifications
should be tolerated (exception not raised). If False, raises an exception
if the version of the model is deprecated. (Default: False).

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
NotImplementedError: If the scope is not supported.
ValueError: If the combination of arguments specified is not supported.
VulnerableJumpStartModelError: If any of the dependencies required by the script have
known security vulnerabilities.
DeprecatedJumpStartModelError: If the version of the model is deprecated.
"""
if is_jumpstart_model_input(model_id, model_version):

# adding assert statements to satisfy mypy type checker
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_image_uri(
model_id,
model_version,
Expand All @@ -106,6 +115,8 @@ def retrieve(
distribution,
base_framework_version,
training_compiler_config,
tolerate_vulnerable_model,
tolerate_deprecated_model,
)

if training_compiler_config is None:
Expand Down
7 changes: 2 additions & 5 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def _validate_and_mutate_region_cache_kwargs(
region (str): The region to validate along with the kwargs.
"""
cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
assert isinstance(cache_kwargs_dict, dict)
if region is not None and "region" in cache_kwargs_dict:
if region != cache_kwargs_dict["region"]:
raise ValueError(
Expand Down Expand Up @@ -92,8 +91,7 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
assert JumpStartModelsAccessor._cache is not None
return JumpStartModelsAccessor._cache.get_header(
return JumpStartModelsAccessor._cache.get_header( # type: ignore
model_id=model_id, semantic_version_str=version
)

Expand All @@ -110,8 +108,7 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
JumpStartModelsAccessor._cache_kwargs, region
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)
assert JumpStartModelsAccessor._cache is not None
return JumpStartModelsAccessor._cache.get_specs(
return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, semantic_version_str=version
)

Expand Down
Loading