Skip to content

feat: jumpstart hyperparameters and env variables #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: feat/jumpstart-retrieve-functions
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ venv/
.docker/
env/
.vscode/
**/tmp
.python-version
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def read_version():
"packaging>=20.0",
"pandas",
"pathos",
"semantic-version",
]

# Specific use case dependencies
Expand Down
53 changes: 53 additions & 0 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve environment variables for hosting containers."""

from __future__ import absolute_import

import logging
from typing import Dict

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts

logger = logging.getLogger(__name__)


def retrieve_default(
    region=None,
    model_id=None,
    model_version=None,
) -> Dict[str, str]:
    """Retrieves the default container environment variables for the model matching the arguments.

    Args:
        region (str): Optional. Region for which to retrieve default environment variables.
            (Default: None).
        model_id (str): Optional. Model ID of the model for which to
            retrieve the default environment variables. (Default: None).
        model_version (str): Optional. Version of the model for which to retrieve the
            default environment variables. (Default: None).
    Returns:
        dict: the variables to use for the model.

    Raises:
        ValueError: If the combination of arguments specified is not supported.
    """
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        # Fixed copy-paste error: the message previously referred to "script URIs",
        # but this accessor retrieves environment variables.
        raise ValueError(
            "Must specify `model_id` and `model_version` when retrieving environment variables."
        )

    # mypy type checking requires these assertions
    assert model_id is not None
    assert model_version is not None

    return artifacts._retrieve_default_environment_variables(model_id, model_version, region)
49 changes: 15 additions & 34 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2426,51 +2426,32 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na

return init_params

def training_image_uri(self):
def training_image_uri(self, region=None):
"""Return the Docker image to use for training.

The :meth:`~sagemaker.estimator.EstimatorBase.fit` method, which does
the model training, calls this method to find the image to use for model
training.

Args:
region (str): Optional. AWS region to use for image URI. Default: AWS region associated
with the SageMaker session.

Returns:
str: The URI of the Docker image.
"""
if self.image_uri:
return self.image_uri
if hasattr(self, "distribution"):
distribution = self.distribution # pylint: disable=no-member
else:
distribution = None
compiler_config = getattr(self, "compiler_config", None)

if hasattr(self, "tensorflow_version") or hasattr(self, "pytorch_version"):
processor = image_uris._processor(self.instance_type, ["cpu", "gpu"])
is_native_huggingface_gpu = processor == "gpu" and not compiler_config
container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None
if self.tensorflow_version is not None: # pylint: disable=no-member
base_framework_version = (
f"tensorflow{self.tensorflow_version}" # pylint: disable=no-member
)
else:
base_framework_version = (
f"pytorch{self.pytorch_version}" # pylint: disable=no-member
)
else:
container_version = None
base_framework_version = None

return image_uris.retrieve(
self._framework_name,
self.sagemaker_session.boto_region_name,
instance_type=self.instance_type,
version=self.framework_version, # pylint: disable=no-member
return image_uris.get_training_image_uri(
region=region or self.sagemaker_session.boto_region_name,
framework=self._framework_name,
framework_version=self.framework_version, # pylint: disable=no-member
py_version=self.py_version, # pylint: disable=no-member
image_scope="training",
distribution=distribution,
base_framework_version=base_framework_version,
container_version=container_version,
training_compiler_config=compiler_config,
image_uri=self.image_uri,
distribution=getattr(self, "distribution", None),
compiler_config=getattr(self, "compiler_config", None),
tensorflow_version=getattr(self, "tensorflow_version", None),
pytorch_version=getattr(self, "pytorch_version", None),
instance_type=self.instance_type,
)

@classmethod
Expand Down
58 changes: 58 additions & 0 deletions src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve hyperparameters for training jobs."""

from __future__ import absolute_import

import logging
from typing import Dict

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts

logger = logging.getLogger(__name__)


def retrieve_default(
    region=None,
    model_id=None,
    model_version=None,
    include_container_hyperparameters=False,
) -> Dict[str, str]:
    """Retrieves the default training hyperparameters for the model matching the given arguments.

    Args:
        region (str): Region for which to retrieve default hyperparameters. (Default: None).
        model_id (str): Model ID of the model for which to
            retrieve the default hyperparameters. (Default: None).
        model_version (str): Version of the model for which to retrieve the
            default hyperparameters. (Default: None).
        include_container_hyperparameters (bool): True if container hyperparameters
            should be returned as well. Container hyperparameters are not used to tune
            the specific algorithm, but rather by SageMaker Training to setup
            the training container environment. For example, there is a container hyperparameter
            that indicates the entrypoint script to use. These hyperparameters may be required
            when creating a training job with boto3, however the ``Estimator`` classes
            should take care of adding container hyperparameters to the job. (Default: False).
    Returns:
        dict: the hyperparameters to use for the model.

    Raises:
        ValueError: If the combination of arguments specified is not supported.
    """
    if not jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        # Fixed copy-paste error: the message previously referred to "script URIs",
        # but this accessor retrieves hyperparameters.
        raise ValueError(
            "Must specify `model_id` and `model_version` when retrieving hyperparameters."
        )

    return artifacts._retrieve_default_hyperparameters(
        model_id, model_version, region, include_container_hyperparameters
    )
99 changes: 98 additions & 1 deletion src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
import logging
import os
import re
from typing import Optional

from sagemaker import utils
from sagemaker.jumpstart.utils import is_jumpstart_model_input
from sagemaker.spark import defaults
from sagemaker.jumpstart import artifacts


logger = logging.getLogger(__name__)

Expand All @@ -39,7 +43,9 @@ def retrieve(
distribution=None,
base_framework_version=None,
training_compiler_config=None,
):
model_id=None,
model_version=None,
) -> str:
"""Retrieves the ECR URI for the Docker image matching the given arguments.

Ideally this function should not be called directly, rather it should be called from the
Expand Down Expand Up @@ -69,13 +75,39 @@ def retrieve(
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
A configuration class for the SageMaker Training Compiler
(default: None).
model_id (str): JumpStart model ID for which to retrieve image URI
(default: None).
model_version (str): Version of the JumpStart model for which to retrieve the
image URI (default: None).

Returns:
str: the ECR URI for the corresponding SageMaker Docker image.

Raises:
ValueError: If the combination of arguments specified is not supported.
"""
if is_jumpstart_model_input(model_id, model_version):

# adding assert statements to satisfy mypy type checker
assert model_id is not None
assert model_version is not None

return artifacts._retrieve_image_uri(
model_id,
model_version,
image_scope,
framework,
region,
version,
py_version,
instance_type,
accelerator_type,
container_version,
distribution,
base_framework_version,
training_compiler_config,
)

if training_compiler_config is None:
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
elif framework == HUGGING_FACE_FRAMEWORK:
Expand Down Expand Up @@ -347,3 +379,68 @@ def _validate_arg(arg, available_options, arg_name):
def _format_tag(tag_prefix, processor, py_version, container_version):
"""Creates a tag for the image URI."""
return "-".join(x for x in (tag_prefix, processor, py_version, container_version) if x)


def get_training_image_uri(
    region,
    framework,
    framework_version=None,
    py_version=None,
    image_uri=None,
    distribution=None,
    compiler_config=None,
    tensorflow_version=None,
    pytorch_version=None,
    instance_type=None,
) -> str:
    """Retrieve the container image URI to use for a training job.

    Args:
        region (str): AWS region to use for image URI.
        framework (str): The framework for which to retrieve an image URI.
        framework_version (str): The framework version for which to retrieve an
            image URI (default: None).
        py_version (str): The python version to use for the image (default: None).
        image_uri (str): If an image URI is supplied, it will be returned (default: None).
        distribution (dict): A dictionary with information on how to run distributed
            training (default: None).
        compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        tensorflow_version (str): Version of tensorflow to use. (default: None)
        pytorch_version (str): Version of pytorch to use. (default: None)
        instance_type (str): Instance type to use. (default: None)

    Returns:
        str: the image URI string.
    """

    # An explicitly supplied image URI always takes precedence over lookup.
    if image_uri:
        return image_uri

    base_framework_version: Optional[str] = None
    container_version: Optional[str] = None

    # A tensorflow/pytorch version alongside `framework` indicates a HuggingFace-style
    # image whose tag embeds the base framework version.
    if tensorflow_version is not None or pytorch_version is not None:
        processor = _processor(instance_type, ["cpu", "gpu"])
        # Native (non-compiler) HuggingFace GPU images pin a specific CUDA/Ubuntu
        # container version.
        is_native_huggingface_gpu = processor == "gpu" and not compiler_config
        container_version = "cu110-ubuntu18.04" if is_native_huggingface_gpu else None
        if tensorflow_version is not None:
            base_framework_version = f"tensorflow{tensorflow_version}"
        else:
            base_framework_version = f"pytorch{pytorch_version}"

    return retrieve(
        framework,
        region,
        instance_type=instance_type,
        version=framework_version,
        py_version=py_version,
        image_scope="training",
        distribution=distribution,
        base_framework_version=base_framework_version,
        container_version=container_version,
        training_compiler_config=compiler_config,
    )
Loading