|
59 | 59 | "local_gpu",
|
60 | 60 | )
|
61 | 61 | SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
|
62 |
| - "tensorflow": ["2.3", "2.3.1", "2.3.2", "2.4", "2.4.1"], |
63 |
| - "pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1"], |
| 62 | + "tensorflow": ["2.3", "2.3.1", "2.3.2", "2.4", "2.4.1", "2.4.3", "2.5", "2.5.0", "2.5.1"], |
| 63 | + "pytorch": ["1.6", "1.6.0", "1.7", "1.7.1", "1.8", "1.8.0", "1.8.1", "1.9", "1.9.0"], |
64 | 64 | }
|
65 | 65 | SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
|
66 | 66 |
|
@@ -298,7 +298,7 @@ def framework_name_from_image(image_uri):
|
298 | 298 | (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost
|
299 | 299 | |huggingface-tensorflow|huggingface-pytorch)(?:-)?
|
300 | 300 | (scriptmode|training)?
|
301 |
| - :(.*)-(.*?)-(py2|py3[67]?)(?:.*)$""", |
| 301 | + :(.*)-(.*?)-(py2|py3\d*)(?:.*)$""", |
302 | 302 | re.VERBOSE,
|
303 | 303 | )
|
304 | 304 | name_match = name_pattern.match(sagemaker_match.group(9))
|
@@ -329,7 +329,7 @@ def framework_version_from_tag(image_tag):
|
329 | 329 | Returns:
|
330 | 330 | str: The framework version.
|
331 | 331 | """
|
332 |
| - tag_pattern = re.compile("^(.*)-(cpu|gpu)-(py2|py3[67]?)$") |
| 332 | + tag_pattern = re.compile(r"^(.*)-(cpu|gpu)-(py2|py3\d*)$") |
333 | 333 | tag_match = tag_pattern.match(image_tag)
|
334 | 334 | return None if tag_match is None else tag_match.group(1)
|
335 | 335 |
|
@@ -533,7 +533,7 @@ def _validate_smdataparallel_args(
|
533 | 533 | if "py3" not in py_version:
|
534 | 534 | err_msg += (
|
535 | 535 | f"Provided py_version {py_version} is not supported by smdataparallel.\n"
|
536 |
| - "Please specify py_version=py3" |
| 536 | + "Please specify py_version>=py3" |
537 | 537 | )
|
538 | 538 |
|
539 | 539 | if err_msg:
|
|
0 commit comments