From f7e91ef01cec82e3e158f2f3e8d7d8e7f8351f82 Mon Sep 17 00:00:00 2001 From: Tabassum Date: Tue, 14 Sep 2021 16:14:59 -0700 Subject: [PATCH] added pytorch 1.8.1 for supporting huggingface --- .../image_uri_config/huggingface.json | 79 +++++++++++++++++-- src/sagemaker/image_uris.py | 27 ++++++- 2 files changed, 100 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 17d4b38c81..0d6cf52892 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -178,7 +178,8 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "huggingface-pytorch-training" + "repository": "huggingface-pytorch-training", + "container_version": {"gpu":"cu110-ubuntu18.04"} }, "pytorch1.7.1": { "py_versions": ["py36"], @@ -209,7 +210,40 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "huggingface-pytorch-training" + "repository": "huggingface-pytorch-training", + "container_version": {"gpu":"cu110-ubuntu18.04"} + }, + "pytorch1.8.1": { + "py_versions": ["py36"], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, 
+ "repository": "huggingface-pytorch-training", + "container_version": {"gpu":"cu111-ubuntu18.04"} }, "tensorflow2.4.1": { "py_versions": ["py37"], @@ -240,7 +274,8 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "huggingface-tensorflow-training" + "repository": "huggingface-tensorflow-training", + "container_version": {"gpu":"cu110-ubuntu18.04"} } } } @@ -286,7 +321,40 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "huggingface-pytorch-inference" + "repository": "huggingface-pytorch-inference", + "container_version": {"gpu":"cu110-ubuntu18.04", "cpu":"ubuntu18.04" } + }, + "pytorch1.8.1": { + "py_versions": ["py36"], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "huggingface-pytorch-inference", + "container_version": {"gpu":"cu111-ubuntu18.04", "cpu":"ubuntu18.04" } }, "tensorflow2.4.1": { "py_versions": ["py37"], @@ -317,7 +385,8 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "huggingface-tensorflow-inference" + "repository": "huggingface-tensorflow-inference", + "container_version": {"gpu":"cu110-ubuntu18.04", "cpu":"ubuntu18.04" } } } } diff --git 
a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 3d511403ed..f02974cb48 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -17,9 +17,12 @@ import logging import os import re +import pdb from sagemaker import utils from sagemaker.spark import defaults +from sagemaker.spark import defaults + logger = logging.getLogger(__name__) @@ -39,7 +42,10 @@ def retrieve( distribution=None, base_framework_version=None, ): + """Retrieves the ECR URI for the Docker image matching the given arguments. + Ideally this function should not be called directly, rather it should be called from the + fit() function inside framework estimator. Args: framework (str): The name of the framework or algorithm. @@ -56,7 +62,11 @@ def retrieve( image_scope (str): The image type, i.e. what it is used for. Valid values: "training", "inference", "eia". If ``accelerator_type`` is set, ``image_scope`` is ignored. - container_version (str): the version of docker image + container_version (str): the version of docker image. + Ideally the value of this parameter should be created inside the framework. + For custom use, see the list of supported container versions: + https://github.com/aws/deep-learning-containers/blob/master/available_images.md + (default: None). distribution (dict): A dictionary with information on how to run distributed training (default: None). @@ -66,10 +76,12 @@ def retrieve( Raises: ValueError: If the combination of arguments specified is not supported. 
""" + config = _config_for_framework_and_scope(framework, image_scope, accelerator_type) original_version = version version = _validate_version_and_set_if_needed(version, config, framework) version_config = config["versions"][_version_for_config(version, config)] + if framework == HUGGING_FACE_FRAMEWORK: if version_config.get("version_aliases"): full_base_framework_version = version_config["version_aliases"].get( @@ -79,9 +91,12 @@ def retrieve( _validate_arg(full_base_framework_version, list(version_config.keys()), "base framework") version_config = version_config.get(full_base_framework_version) + py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config + + registry = _registry_from_region(region, version_config["registries"]) hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"] @@ -90,12 +105,16 @@ def retrieve( processor = _processor( instance_type, config.get("processors") or version_config.get("processors") ) + #if container version is available in .json file, utilize that + if "container_version" in version_config.keys(): + container_version = version_config['container_version'][processor] if framework == HUGGING_FACE_FRAMEWORK: pt_or_tf_version = ( re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2) ) tag_prefix = f"{pt_or_tf_version}-transformers{original_version}" + else: tag_prefix = version_config.get("tag_prefix", version) @@ -105,6 +124,8 @@ def retrieve( py_version, container_version, ) + + if _should_auto_select_container_version(instance_type, distribution): container_versions = { "tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3", @@ -119,8 +140,12 @@ def retrieve( "pytorch-1.6.0-gpu-py36": "cu110-ubuntu18.04", "pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3", "pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04", + "pytorch-1.8.1-gpu-py3": "cu111-ubuntu18.04" } + + key = "-".join([framework, tag]) + if 
key in container_versions: tag = "-".join([tag, container_versions[key]])