Skip to content

Commit 3bb893f

Browse files
jeniyat, Tabassum, ahsan-z-khan, shreyapandit
authored and committed
fix: add pytorch 1.8.1 for huggingface (aws#2642)
* added pytorch 1.8.1 for supporting huggingface
* fix: add pytorch 1.8.1 for huggingface
* fix: add pytorch 1.8.1 for huggingface
* refactored code for flake
* refactored code for docstyle
* fix: add alias for pytorch 1.8
* update to master
* add version alias for pytorch1.8 in huggingface.json
* add version alias for pytorch1.8 in huggingface.json
* removed empty file, corrected grammar
* removed empty line in test_estimator
* removed empty line in test_estimator
* removed empty line in test_estimator

Co-authored-by: Tabassum <[email protected]>
Co-authored-by: Ahsan Khan <[email protected]>
Co-authored-by: Shreya Pandit <[email protected]>
1 parent 6485a57 commit 3bb893f

File tree

3 files changed

+94
-8
lines changed

3 files changed

+94
-8
lines changed

src/sagemaker/image_uri_config/huggingface.json

+75-5
Original file line number | Diff line number | Diff line change
@@ -147,6 +147,7 @@
147147
"version_aliases": {
148148
"pytorch1.6": "pytorch1.6.0",
149149
"pytorch1.7": "pytorch1.7.1",
150+
"pytorch1.8": "pytorch1.8.1",
150151
"tensorflow2.4": "tensorflow2.4.1"
151152
},
152153
"pytorch1.6.0": {
@@ -178,7 +179,8 @@
178179
"us-west-1": "763104351884",
179180
"us-west-2": "763104351884"
180181
},
181-
"repository": "huggingface-pytorch-training"
182+
"repository": "huggingface-pytorch-training",
183+
"container_version": {"gpu":"cu110-ubuntu18.04"}
182184
},
183185
"pytorch1.7.1": {
184186
"py_versions": ["py36"],
@@ -209,7 +211,40 @@
209211
"us-west-1": "763104351884",
210212
"us-west-2": "763104351884"
211213
},
212-
"repository": "huggingface-pytorch-training"
214+
"repository": "huggingface-pytorch-training",
215+
"container_version": {"gpu":"cu110-ubuntu18.04"}
216+
},
217+
"pytorch1.8.1": {
218+
"py_versions": ["py36"],
219+
"registries": {
220+
"af-south-1": "626614931356",
221+
"ap-east-1": "871362719292",
222+
"ap-northeast-1": "763104351884",
223+
"ap-northeast-2": "763104351884",
224+
"ap-northeast-3": "364406365360",
225+
"ap-south-1": "763104351884",
226+
"ap-southeast-1": "763104351884",
227+
"ap-southeast-2": "763104351884",
228+
"ca-central-1": "763104351884",
229+
"cn-north-1": "727897471807",
230+
"cn-northwest-1": "727897471807",
231+
"eu-central-1": "763104351884",
232+
"eu-north-1": "763104351884",
233+
"eu-west-1": "763104351884",
234+
"eu-west-2": "763104351884",
235+
"eu-west-3": "763104351884",
236+
"eu-south-1": "692866216735",
237+
"me-south-1": "217643126080",
238+
"sa-east-1": "763104351884",
239+
"us-east-1": "763104351884",
240+
"us-east-2": "763104351884",
241+
"us-gov-west-1": "442386744353",
242+
"us-iso-east-1": "886529160074",
243+
"us-west-1": "763104351884",
244+
"us-west-2": "763104351884"
245+
},
246+
"repository": "huggingface-pytorch-training",
247+
"container_version": {"gpu":"cu111-ubuntu18.04"}
213248
},
214249
"tensorflow2.4.1": {
215250
"py_versions": ["py37"],
@@ -240,7 +275,8 @@
240275
"us-west-1": "763104351884",
241276
"us-west-2": "763104351884"
242277
},
243-
"repository": "huggingface-tensorflow-training"
278+
"repository": "huggingface-tensorflow-training",
279+
"container_version": {"gpu":"cu110-ubuntu18.04"}
244280
}
245281
}
246282
}
@@ -286,7 +322,40 @@
286322
"us-west-1": "763104351884",
287323
"us-west-2": "763104351884"
288324
},
289-
"repository": "huggingface-pytorch-inference"
325+
"repository": "huggingface-pytorch-inference",
326+
"container_version": {"gpu":"cu110-ubuntu18.04", "cpu":"ubuntu18.04" }
327+
},
328+
"pytorch1.8.1": {
329+
"py_versions": ["py36"],
330+
"registries": {
331+
"af-south-1": "626614931356",
332+
"ap-east-1": "871362719292",
333+
"ap-northeast-1": "763104351884",
334+
"ap-northeast-2": "763104351884",
335+
"ap-northeast-3": "364406365360",
336+
"ap-south-1": "763104351884",
337+
"ap-southeast-1": "763104351884",
338+
"ap-southeast-2": "763104351884",
339+
"ca-central-1": "763104351884",
340+
"cn-north-1": "727897471807",
341+
"cn-northwest-1": "727897471807",
342+
"eu-central-1": "763104351884",
343+
"eu-north-1": "763104351884",
344+
"eu-west-1": "763104351884",
345+
"eu-west-2": "763104351884",
346+
"eu-west-3": "763104351884",
347+
"eu-south-1": "692866216735",
348+
"me-south-1": "217643126080",
349+
"sa-east-1": "763104351884",
350+
"us-east-1": "763104351884",
351+
"us-east-2": "763104351884",
352+
"us-gov-west-1": "442386744353",
353+
"us-iso-east-1": "886529160074",
354+
"us-west-1": "763104351884",
355+
"us-west-2": "763104351884"
356+
},
357+
"repository": "huggingface-pytorch-inference",
358+
"container_version": {"gpu":"cu111-ubuntu18.04", "cpu":"ubuntu18.04" }
290359
},
291360
"tensorflow2.4.1": {
292361
"py_versions": ["py37"],
@@ -317,7 +386,8 @@
317386
"us-west-1": "763104351884",
318387
"us-west-2": "763104351884"
319388
},
320-
"repository": "huggingface-tensorflow-inference"
389+
"repository": "huggingface-tensorflow-inference",
390+
"container_version": {"gpu":"cu110-ubuntu18.04", "cpu":"ubuntu18.04" }
321391
}
322392
}
323393
}

src/sagemaker/image_uris.py

+18-2
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,9 @@ def retrieve(
4141
):
4242
"""Retrieves the ECR URI for the Docker image matching the given arguments.
4343
44+
Ideally this function should not be called directly, rather it should be called from the
45+
fit() function inside framework estimator.
46+
4447
Args:
4548
framework (str): The name of the framework or algorithm.
4649
region (str): The AWS region.
@@ -56,7 +59,11 @@ def retrieve(
5659
image_scope (str): The image type, i.e. what it is used for.
5760
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
5861
``image_scope`` is ignored.
59-
container_version (str): the version of docker image
62+
container_version (str): the version of docker image.
63+
Ideally the value of parameter should be created inside the framework.
64+
For custom use, see the list of supported container versions:
65+
https://github.com/aws/deep-learning-containers/blob/master/available_images.md
66+
(default: None).
6067
distribution (dict): A dictionary with information on how to run distributed training
6168
(default: None).
6269
@@ -66,10 +73,12 @@ def retrieve(
6673
Raises:
6774
ValueError: If the combination of arguments specified is not supported.
6875
"""
76+
6977
config = _config_for_framework_and_scope(framework, image_scope, accelerator_type)
7078
original_version = version
7179
version = _validate_version_and_set_if_needed(version, config, framework)
7280
version_config = config["versions"][_version_for_config(version, config)]
81+
7382
if framework == HUGGING_FACE_FRAMEWORK:
7483
if version_config.get("version_aliases"):
7584
full_base_framework_version = version_config["version_aliases"].get(
@@ -81,7 +90,6 @@ def retrieve(
8190

8291
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
8392
version_config = version_config.get(py_version) or version_config
84-
8593
registry = _registry_from_region(region, version_config["registries"])
8694
hostname = utils._botocore_resolver().construct_endpoint("ecr", region)["hostname"]
8795

@@ -91,11 +99,16 @@ def retrieve(
9199
instance_type, config.get("processors") or version_config.get("processors")
92100
)
93101

102+
# if container version is available in .json file, utilize that
103+
if version_config.get("container_version"):
104+
container_version = version_config["container_version"][processor]
105+
94106
if framework == HUGGING_FACE_FRAMEWORK:
95107
pt_or_tf_version = (
96108
re.compile("^(pytorch|tensorflow)(.*)$").match(base_framework_version).group(2)
97109
)
98110
tag_prefix = f"{pt_or_tf_version}-transformers{original_version}"
111+
99112
else:
100113
tag_prefix = version_config.get("tag_prefix", version)
101114

@@ -105,6 +118,7 @@ def retrieve(
105118
py_version,
106119
container_version,
107120
)
121+
108122
if _should_auto_select_container_version(instance_type, distribution):
109123
container_versions = {
110124
"tensorflow-2.3-gpu-py37": "cu110-ubuntu18.04-v3",
@@ -120,7 +134,9 @@ def retrieve(
120134
"pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3",
121135
"pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04",
122136
}
137+
123138
key = "-".join([framework, tag])
139+
124140
if key in container_versions:
125141
tag = "-".join([tag, container_versions[key]])
126142

tests/conftest.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -400,7 +400,7 @@ def _huggingface_base_fm_version(huggingface_version, base_fw, fixture_prefix):
400400
if len(original_version.split(".")) == 2:
401401
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
402402
versions.append(base_fw_version)
403-
return versions
403+
return sorted(versions, reverse=True)
404404

405405

406406
def _generate_huggingface_base_fw_latest_versions(

0 commit comments

Comments
 (0)