Skip to content

breaking: rename image_name to image_uri #1667

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 6, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(
hyperparameters=None,
framework_version=None,
py_version=None,
image_name=None,
image_uri=None,
**kwargs
):
"""This ``Estimator`` executes an Chainer script in a managed Chainer
Expand Down Expand Up @@ -101,13 +101,13 @@ def __init__(
and values, but ``str()`` will be called to convert them before
training.
py_version (str): Python version you want to use for executing your
model training code. Defaults to ``None``. Required unless ``image_name``
model training code. Defaults to ``None``. Required unless ``image_uri``
is provided.
framework_version (str): Chainer version you want to use for
executing your model training code. Defaults to ``None``. Required unless
``image_name`` is provided. List of supported versions:
``image_uri`` is provided. List of supported versions:
https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
image_name (str): If specified, the estimator will use this image
image_uri (str): If specified, the estimator will use this image
for training and hosting, instead of selecting the appropriate
SageMaker official image based on framework_version and
py_version. It can be an ECR url or dockerhub image and tag.
Expand All @@ -117,7 +117,7 @@ def __init__(
* ``custom-image:latest``

If ``framework_version`` or ``py_version`` are ``None``, then
``image_name`` is required. If also ``None``, then a ``ValueError``
``image_uri`` is required. If also ``None``, then a ``ValueError``
will be raised.
**kwargs: Additional kwargs passed to the
:class:`~sagemaker.estimator.Framework` constructor.
Expand All @@ -128,7 +128,7 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
validate_version_or_image_args(framework_version, py_version, image_name)
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
Expand All @@ -137,7 +137,7 @@ def __init__(
self.py_version = py_version

super(Chainer, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
)

self.use_mpi = use_mpi
Expand Down Expand Up @@ -209,7 +209,7 @@ def create_model(
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

if "image" not in kwargs:
kwargs["image"] = self.image_name
kwargs["image"] = self.image_uri

return ChainerModel(
self.model_data,
Expand Down Expand Up @@ -257,8 +257,8 @@ class constructor
if value:
init_params[argument[len("sagemaker_") :]] = value

image_name = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_name)
image_uri = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_uri)

if tag is None:
framework_version = None
Expand All @@ -270,7 +270,7 @@ class constructor
if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params["image_name"] = image_name
init_params["image_uri"] = image_uri
return init_params

if framework != cls.__framework_name__:
Expand Down
28 changes: 12 additions & 16 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
UploadedCode,
validate_source_dir,
_region_supports_debugger,
parameter_v2_rename_warning,
)
from sagemaker.job import _Job
from sagemaker.local import LocalSession
Expand Down Expand Up @@ -1131,7 +1130,7 @@ class Estimator(EstimatorBase):

def __init__(
self,
image_name,
image_uri,
role,
train_instance_count,
train_instance_type,
Expand Down Expand Up @@ -1164,7 +1163,7 @@ def __init__(
"""Initialize an ``Estimator`` instance.

Args:
image_name (str): The container image to use for training.
image_uri (str): The container image to use for training.
role (str): An AWS IAM role (either name or full ARN). The Amazon
SageMaker training jobs and APIs that create Amazon SageMaker
endpoints use this role to access training data and model
Expand Down Expand Up @@ -1273,8 +1272,7 @@ def __init__(
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
(default: ``None``).
"""
logging.warning(parameter_v2_rename_warning("image_name", "image_uri"))
self.image_name = image_name
self.image_uri = image_uri
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
super(Estimator, self).__init__(
role,
Expand Down Expand Up @@ -1312,7 +1310,7 @@ def train_image(self):
The fit() method, that does the model training, calls this method to
find the image to use for model training.
"""
return self.image_name
return self.image_uri

def set_hyperparameters(self, **kwargs):
"""
Expand Down Expand Up @@ -1422,7 +1420,7 @@ class constructor
job_details, model_channel_name
)

init_params["image_name"] = init_params.pop("image")
init_params["image_uri"] = init_params.pop("image")
return init_params


Expand All @@ -1449,7 +1447,7 @@ def __init__(
enable_cloudwatch_metrics=False,
container_log_level=logging.INFO,
code_location=None,
image_name=None,
image_uri=None,
dependencies=None,
enable_network_isolation=False,
git_config=None,
Expand Down Expand Up @@ -1515,7 +1513,7 @@ def __init__(
a string prepended with a "/" is appended to ``code_location``. The code
file uploaded to S3 is 'code_location/job-name/source/sourcedir.tar.gz'.
If not specified, the default ``code location`` is s3://output_bucket/job-name/.
image_name (str): An alternate image name to use instead of the
image_uri (str): An alternate image name to use instead of the
official Sagemaker image for the framework. This is useful to
run one of the Sagemaker supported frameworks with an image
containing custom dependencies.
Expand Down Expand Up @@ -1635,6 +1633,8 @@ def __init__(
self.git_config = git_config
self.source_dir = source_dir
self.dependencies = dependencies or []
self.uploaded_code = None

if enable_cloudwatch_metrics:
warnings.warn(
"enable_cloudwatch_metrics is now deprecated and will be removed in the future.",
Expand All @@ -1643,11 +1643,7 @@ def __init__(
self.enable_cloudwatch_metrics = False
self.container_log_level = container_log_level
self.code_location = code_location
self.image_name = image_name
if image_name is not None:
logging.warning(parameter_v2_rename_warning("image_name", "image_uri"))

self.uploaded_code = None
self.image_uri = image_uri

self._hyperparameters = hyperparameters or {}
self.checkpoint_s3_uri = checkpoint_s3_uri
Expand Down Expand Up @@ -1833,8 +1829,8 @@ def train_image(self):
Returns:
str: The URI of the Docker image.
"""
if self.image_name:
return self.image_name
if self.image_uri:
return self.image_uri
return create_image_uri(
self.sagemaker_session.boto_region_name,
self.__framework_name__,
Expand Down
18 changes: 9 additions & 9 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,11 @@ def _list_files_to_compress(script, directory):
return [os.path.join(basedir, name) for name in os.listdir(basedir)]


def framework_name_from_image(image_name):
def framework_name_from_image(image_uri):
# noinspection LongLine
"""Extract the framework and Python version from the image name.
Args:
image_name (str): Image URI, which should be one of the following forms:
image_uri (str): Image URI, which should be one of the following forms:
legacy:
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<container_version>'
legacy:
Expand All @@ -509,7 +509,7 @@ def framework_name_from_image(image_name):
str: If the image is script mode
"""
sagemaker_pattern = re.compile(ECR_URI_PATTERN)
sagemaker_match = sagemaker_pattern.match(image_name)
sagemaker_match = sagemaker_pattern.match(image_uri)
if sagemaker_match is None:
return None, None, None, None
# extract framework, python version and image tag
Expand Down Expand Up @@ -691,22 +691,22 @@ def _region_supports_debugger(region_name):
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS


def validate_version_or_image_args(framework_version, py_version, image_name):
def validate_version_or_image_args(framework_version, py_version, image_uri):
"""Checks if version or image arguments are specified.

Validates framework and model arguments to enforce version or image specification.

Args:
framework_version (str): The version of the framework.
py_version (str): The version of Python.
image_name (str): The URI of the image.
image_uri (str): The URI of the image.

Raises:
ValueError: if `image_name` is None and either `framework_version` or `py_version` is
ValueError: if `image_uri` is None and either `framework_version` or `py_version` is
None.
"""
if (framework_version is None or py_version is None) and image_name is None:
if (framework_version is None or py_version is None) and image_uri is None:
raise ValueError(
"framework_version or py_version was None, yet image_name was also None. "
"Either specify both framework_version and py_version, or specify image_name."
"framework_version or py_version was None, yet image_uri was also None. "
"Either specify both framework_version and py_version, or specify image_uri."
)
26 changes: 13 additions & 13 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
py_version=None,
source_dir=None,
hyperparameters=None,
image_name=None,
image_uri=None,
distribution=None,
**kwargs
):
Expand Down Expand Up @@ -72,11 +72,11 @@ def __init__(
must point to a file located at the root of ``source_dir``.
framework_version (str): MXNet version you want to use for executing
your model training code. Defaults to `None`. Required unless
``image_name`` is provided. List of supported versions.
``image_uri`` is provided. List of supported versions.
https://github.com/aws/sagemaker-python-sdk#mxnet-sagemaker-estimators.
py_version (str): Python version you want to use for executing your
model training code. One of 'py2' or 'py3'. Defaults to ``None``. Required
unless ``image_name`` is provided.
unless ``image_uri`` is provided.
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
with any other training source code dependencies aside from the entry
point file (default: None). If ``source_dir`` is an S3 URI, it must
Expand All @@ -88,7 +88,7 @@ def __init__(
SageMaker. For convenience, this accepts other types for keys
and values, but ``str()`` will be called to convert them before
training.
image_name (str): If specified, the estimator will use this image for training and
image_uri (str): If specified, the estimator will use this image for training and
hosting, instead of selecting the appropriate SageMaker official image based on
framework_version and py_version. It can be an ECR url or dockerhub image and tag.

Expand All @@ -97,7 +97,7 @@ def __init__(
* ``custom-image:latest``

If ``framework_version`` or ``py_version`` are ``None``, then
``image_name`` is required. If also ``None``, then a ``ValueError``
``image_uri`` is required. If also ``None``, then a ``ValueError``
will be raised.
distribution (dict): A dictionary with information on how to run distributed
training (default: None). To have parameter servers launched for training,
Expand All @@ -111,7 +111,7 @@ def __init__(
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
validate_version_or_image_args(framework_version, py_version, image_name)
validate_version_or_image_args(framework_version, py_version, image_uri)
if py_version == "py2":
logger.warning(
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
Expand All @@ -127,7 +127,7 @@ def __init__(
kwargs["enable_sagemaker_metrics"] = True

super(MXNet, self).__init__(
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
)

if distribution is not None:
Expand Down Expand Up @@ -168,7 +168,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
image_name=None,
image_uri=None,
**kwargs
):
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
Expand Down Expand Up @@ -198,7 +198,7 @@ def create_model(
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
This is not supported with "local code" in Local Mode.
image_name (str): If specified, the estimator will use this image for hosting, instead
image_uri (str): If specified, the estimator will use this image for hosting, instead
of selecting the appropriate SageMaker official image based on framework_version
and py_version. It can be an ECR url or dockerhub image and tag.

Expand All @@ -214,7 +214,7 @@ def create_model(
See :func:`~sagemaker.mxnet.model.MXNetModel` for full details.
"""
if "image" not in kwargs:
kwargs["image"] = image_name or self.image_name
kwargs["image"] = image_uri or self.image_uri

kwargs["name"] = self._get_or_create_name(kwargs.get("name"))

Expand Down Expand Up @@ -252,8 +252,8 @@ class constructor
init_params = super(MXNet, cls)._prepare_init_params_from_job_description(
job_details, model_channel_name
)
image_name = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_name)
image_uri = init_params.pop("image")
framework, py_version, tag, _ = framework_name_from_image(image_uri)

# We switched image tagging scheme from regular image version (e.g. '1.0') to more
# expressive containing framework version, device type and python version
Expand All @@ -271,7 +271,7 @@ class constructor
if not framework:
# If we were unable to parse the framework name from the image it is not one of our
# officially supported images, in this case just add the image to the init params.
init_params["image_name"] = image_name
init_params["image_uri"] = image_uri
return init_params

if framework != cls.__framework_name__:
Expand Down
Loading