Skip to content

Commit f7e91ef

Browse files
author
Tabassum
committed
added pytorch 1.8.1 for supporting huggingface
1 parent 8447430 commit f7e91ef

File tree

2 files changed

+100
-6
lines changed

2 files changed

+100
-6
lines changed

src/sagemaker/image_uri_config/huggingface.json

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@
178178
"us-west-1": "763104351884",
179179
"us-west-2": "763104351884"
180180
},
181-
"repository": "huggingface-pytorch-training"
181+
"repository": "huggingface-pytorch-training",
182+
"container_version": {"gpu":"cu110-ubuntu18.04"}
182183
},
183184
"pytorch1.7.1": {
184185
"py_versions": ["py36"],
@@ -209,7 +210,40 @@
209210
"us-west-1": "763104351884",
210211
"us-west-2": "763104351884"
211212
},
212-
"repository": "huggingface-pytorch-training"
213+
"repository": "huggingface-pytorch-training",
214+
"container_version": {"gpu":"cu110-ubuntu18.04"}
215+
},
216+
"pytorch1.8.1": {
217+
"py_versions": ["py36"],
218+
"registries": {
219+
"af-south-1": "626614931356",
220+
"ap-east-1": "871362719292",
221+
"ap-northeast-1": "763104351884",
222+
"ap-northeast-2": "763104351884",
223+
"ap-northeast-3": "364406365360",
224+
"ap-south-1": "763104351884",
225+
"ap-southeast-1": "763104351884",
226+
"ap-southeast-2": "763104351884",
227+
"ca-central-1": "763104351884",
228+
"cn-north-1": "727897471807",
229+
"cn-northwest-1": "727897471807",
230+
"eu-central-1": "763104351884",
231+
"eu-north-1": "763104351884",
232+
"eu-west-1": "763104351884",
233+
"eu-west-2": "763104351884",
234+
"eu-west-3": "763104351884",
235+
"eu-south-1": "692866216735",
236+
"me-south-1": "217643126080",
237+
"sa-east-1": "763104351884",
238+
"us-east-1": "763104351884",
239+
"us-east-2": "763104351884",
240+
"us-gov-west-1": "442386744353",
241+
"us-iso-east-1": "886529160074",
242+
"us-west-1": "763104351884",
243+
"us-west-2": "763104351884"
244+
},
245+
"repository": "huggingface-pytorch-training",
246+
"container_version": {"gpu":"cu111-ubuntu18.04"}
213247
},
214248
"tensorflow2.4.1": {
215249
"py_versions": ["py37"],
@@ -240,7 +274,8 @@
240274
"us-west-1": "763104351884",
241275
"us-west-2": "763104351884"
242276
},
243-
"repository": "huggingface-tensorflow-training"
277+
"repository": "huggingface-tensorflow-training",
278+
"container_version": {"gpu":"cu110-ubuntu18.04"}
244279
}
245280
}
246281
}
@@ -286,7 +321,40 @@
286321
"us-west-1": "763104351884",
287322
"us-west-2": "763104351884"
288323
},
289-
"repository": "huggingface-pytorch-inference"
324+
"repository": "huggingface-pytorch-inference",
325+
"container_version": {"gpu":"cu110-ubuntu18.04", "cpu":"ubuntu18.04" }
326+
},
327+
"pytorch1.8.1": {
328+
"py_versions": ["py36"],
329+
"registries": {
330+
"af-south-1": "626614931356",
331+
"ap-east-1": "871362719292",
332+
"ap-northeast-1": "763104351884",
333+
"ap-northeast-2": "763104351884",
334+
"ap-northeast-3": "364406365360",
335+
"ap-south-1": "763104351884",
336+
"ap-southeast-1": "763104351884",
337+
"ap-southeast-2": "763104351884",
338+
"ca-central-1": "763104351884",
339+
"cn-north-1": "727897471807",
340+
"cn-northwest-1": "727897471807",
341+
"eu-central-1": "763104351884",
342+
"eu-north-1": "763104351884",
343+
"eu-west-1": "763104351884",
344+
"eu-west-2": "763104351884",
345+
"eu-west-3": "763104351884",
346+
"eu-south-1": "692866216735",
347+
"me-south-1": "217643126080",
348+
"sa-east-1": "763104351884",
349+
"us-east-1": "763104351884",
350+
"us-east-2": "763104351884",
351+
"us-gov-west-1": "442386744353",
352+
"us-iso-east-1": "886529160074",
353+
"us-west-1": "763104351884",
354+
"us-west-2": "763104351884"
355+
},
356+
"repository": "huggingface-pytorch-inference",
357+
"container_version": {"gpu":"cu111-ubuntu18.04", "cpu":"ubuntu18.04" }
290358
},
291359
"tensorflow2.4.1": {
292360
"py_versions": ["py37"],
@@ -317,7 +385,8 @@
317385
"us-west-1": "763104351884",
318386
"us-west-2": "763104351884"
319387
},
320-
"repository": "huggingface-tensorflow-inference"
388+
"repository": "huggingface-tensorflow-inference",
389+
"container_version": {"gpu":"cu110-ubuntu18.04", "cpu":"ubuntu18.04" }
321390
}
322391
}
323392
}

src/sagemaker/image_uris.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,12 @@
1717
import logging
1818
import os
1919
import re
20+
import pdb
2021

2122
from sagemaker import utils
2223
from sagemaker.spark import defaults
24+
from sagemaker.spark import defaults
25+
2326

2427
logger = logging.getLogger(__name__)
2528

@@ -39,7 +42,10 @@ def retrieve(
3942
distribution=None,
4043
base_framework_version=None,
4144
):
45+
4246
"""Retrieves the ECR URI for the Docker image matching the given arguments.
47+
Ideally this function should not be called directly, rather it should be called from the
48+
fit() function inside framework estimator.
4349
4450
Args:
4551
framework (str): The name of the framework or algorithm.
@@ -56,7 +62,11 @@ def retrieve(
5662
image_scope (str): The image type, i.e. what it is used for.
5763
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
5864
``image_scope`` is ignored.
59-
container_version (str): the version of docker image
65+
container_version (str): the version of docker image.
66+
Ideally the value of parameter is should be created inside the framework.
67+
For custom use, see the list of supported container versions:
68+
https://github.com/aws/deep-learning-containers/blob/master/available_images.md
69+
(default: None).
6070
distribution (dict): A dictionary with information on how to run distributed training
6171
(default: None).
6272
@@ -66,10 +76,12 @@ def retrieve(
6676
Raises:
6777
ValueError: If the combination of arguments specified is not supported.
6878
"""
79+
6980
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
7081
original_version = version
7182
version = _validate_version_and_set_if_needed(version, config, framework)
7283
version_config = config["versions"][_version_for_config(version, config)]
84+
7385
if framework == HUGGING_FACE_FRAMEWORK:
7486
if version_config.get("version_aliases"):
7587
full_base_framework_version = version_config["version_aliases"].get(
@@ -79,9 +91,12 @@ def retrieve(
7991
_validate_arg(full_base_framework_version, list(version_config.keys()), "base framework")
8092
version_config = version_config.get(full_base_framework_version)
8193

94+
8295
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
8396
version_config = version_config.get(py_version) or version_config
8497

98+
99+
85100
registry = _registry_from_region(region, version_config["registries"])
86101
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]
87102

@@ -90,12 +105,16 @@ def retrieve(
90105
processor = _processor(
91106
instance_type, config.get("processors") or version_config.get("processors")
92107
)
108+
#if container version is available in .json file, utilize that
109+
if "container_version" in version_config.keys():
110+
container_version = version_config['container_version'][processor]
93111

94112
if framework == HUGGING_FACE_FRAMEWORK:
95113
pt_or_tf_version = (
96114
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
97115
)
98116
tag_prefix = f"{pt_or_tf_version}-transformers{original_version}"
117+
99118
else:
100119
tag_prefix = version_config.get("tag_prefix", version)
101120

@@ -105,6 +124,8 @@ def retrieve(
105124
py_version,
106125
container_version,
107126
)
127+
128+
108129
if _should_auto_select_container_version(instance_type, distribution):
109130
container_versions = {
110131
"tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3",
@@ -119,8 +140,12 @@ def retrieve(
119140
"pytorch-1.6.0-gpu-py36": "cu110-ubuntu18.04",
120141
"pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3",
121142
"pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04",
143+
"pytorch-1.8.1-gpu-py3": "cu111-ubuntu18.04"
122144
}
145+
146+
123147
key = "-".join([framework, tag])
148+
124149
if key in container_versions:
125150
tag = "-".join([tag, container_versions[key]])
126151

0 commit comments

Comments
 (0)