Skip to content

fix: add pytorch 1.8.1 for huggingface #2639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions src/sagemaker/image_uri_config/huggingface.json
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-pytorch-training"
"repository": "huggingface-pytorch-training",
"container_version": {"gpu":"cu110-ubuntu18.04"}
},
"pytorch1.7.1": {
"py_versions": ["py36"],
Expand Down Expand Up @@ -209,7 +210,40 @@
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-pytorch-training"
"repository": "huggingface-pytorch-training",
"container_version": {"gpu":"cu110-ubuntu18.04"}
},
"pytorch1.8.1": {
"py_versions": ["py36"],
"registries": {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-pytorch-training",
"container_version": {"gpu":"cu111-ubuntu18.04"}
},
"tensorflow2.4.1": {
"py_versions": ["py37"],
Expand Down Expand Up @@ -240,7 +274,8 @@
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-tensorflow-training"
"repository": "huggingface-tensorflow-training",
"container_version": {"gpu":"cu110-ubuntu18.04"}
}
}
}
Expand Down Expand Up @@ -286,7 +321,40 @@
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-pytorch-inference"
"repository": "huggingface-pytorch-inference",
"container_version": {"gpu":"cu110-ubuntu18.04", "cpu":"ubuntu18.04" }
},
"pytorch1.8.1": {
"py_versions": ["py36"],
"registries": {
"af-south-1": "626614931356",
"ap-east-1": "871362719292",
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-pytorch-inference",
"container_version": {"gpu":"cu111-ubuntu18.04", "cpu":"ubuntu18.04" }
},
"tensorflow2.4.1": {
"py_versions": ["py37"],
Expand Down Expand Up @@ -317,7 +385,8 @@
"us-west-1": "763104351884",
"us-west-2": "763104351884"
},
"repository": "huggingface-tensorflow-inference"
"repository": "huggingface-tensorflow-inference",
"container_version": {"gpu":"cu110-ubuntu18.04", "cpu":"ubuntu18.04" }
}
}
}
Expand Down
27 changes: 26 additions & 1 deletion src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
import logging
import os
import re
import pdb
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be removed.


from sagemaker import utils
from sagemaker.spark import defaults
from sagemaker.spark import defaults
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got added twice?



logger = logging.getLogger(__name__)

Expand All @@ -39,7 +42,10 @@ def retrieve(
distribution=None,
base_framework_version=None,
):

"""Retrieves the ECR URI for the Docker image matching the given arguments.
Ideally this function should not be called directly, rather it should be called from the
fit() function inside framework estimator.

Args:
framework (str): The name of the framework or algorithm.
Expand All @@ -56,7 +62,11 @@ def retrieve(
image_scope (str): The image type, i.e. what it is used for.
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
``image_scope`` is ignored.
container_version (str): the version of docker image
container_version (str): the version of docker image.
Ideally the value of parameter is should be created inside the framework.
For custom use, see the list of supported container versions:
https://github.com/aws/deep-learning-containers/blob/master/available_images.md
(default: None).
distribution (dict): A dictionary with information on how to run distributed training
(default: None).

Expand All @@ -66,10 +76,12 @@ def retrieve(
Raises:
ValueError: If the combination of arguments specified is not supported.
"""

config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
original_version = version
version = _validate_version_and_set_if_needed(version, config, framework)
version_config = config["versions"][_version_for_config(version, config)]

if framework == HUGGING_FACE_FRAMEWORK:
if version_config.get("version_aliases"):
full_base_framework_version = version_config["version_aliases"].get(
Expand All @@ -79,9 +91,12 @@ def retrieve(
_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
version_config = version_config.get(full_base_framework_version)


py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
version_config = version_config.get(py_version) or version_config



registry = _registry_from_region(region, version_config["registries"])
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]

Expand All @@ -90,12 +105,16 @@ def retrieve(
processor = _processor(
instance_type, config.get("processors") or version_config.get("processors")
)
#if container version is available in .json file, utilize that
if "container_version" in version_config.keys():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use if version_config.get("container_version"): ?

container_version = version_config['container_version'][processor]

if framework == HUGGING_FACE_FRAMEWORK:
pt_or_tf_version = (
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
)
tag_prefix = f"{pt_or_tf_version}-transformers{original_version}"

else:
tag_prefix = version_config.get("tag_prefix", version)

Expand All @@ -105,6 +124,8 @@ def retrieve(
py_version,
container_version,
)


if _should_auto_select_container_version(instance_type, distribution):
container_versions = {
"tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3",
Expand All @@ -119,8 +140,12 @@ def retrieve(
"pytorch-1.6.0-gpu-py36": "cu110-ubuntu18.04",
"pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3",
"pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04",
"pytorch-1.8.1-gpu-py3": "cu111-ubuntu18.04"
}


key = "-".join([framework, tag])

if key in container_versions:
tag = "-".join([tag, container_versions[key]])

Expand Down