Skip to content

Commit e40638a

Browse files
saimidumufaddal-rohawala
authored andcommitted
feature: Add support for TF 2.8 (aws#3000)
* feature: Add support for TF 2.7 and TF 2.8 * Correct TF 2.7 patch version used in tests * Correct TF 2.7 patch version for SMDDP TF versions * Only add entries for TF 2.8 * Fix test options for TF 2.8 for fw_utils * Update tensorflow.json Co-authored-by: Mufaddal Rohawala <[email protected]>
1 parent ea86e37 commit e40638a

File tree

4 files changed

+81
-4
lines changed

4 files changed

+81
-4
lines changed

src/sagemaker/fw_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
7575
"2.6.0",
7676
"2.6.2",
7777
"2.6.3",
78+
"2.8",
79+
"2.8.0",
7880
],
7981
"pytorch": [
8082
"1.6",

src/sagemaker/image_uri_config/tensorflow.json

+67-2
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@
279279
"2.3": "2.3.2",
280280
"2.4": "2.4.3",
281281
"2.5": "2.5.1",
282-
"2.6": "2.6.3"
282+
"2.6": "2.6.3",
283+
"2.8": "2.8.0"
283284
},
284285
"versions": {
285286
"1.10.0": {
@@ -1373,6 +1374,36 @@
13731374
"us-west-2": "763104351884"
13741375
},
13751376
"repository": "tensorflow-inference"
1377+
},
1378+
"2.8.0": {
1379+
"registries": {
1380+
"af-south-1": "626614931356",
1381+
"ap-east-1": "871362719292",
1382+
"ap-northeast-1": "763104351884",
1383+
"ap-northeast-2": "763104351884",
1384+
"ap-northeast-3": "364406365360",
1385+
"ap-south-1": "763104351884",
1386+
"ap-southeast-1": "763104351884",
1387+
"ap-southeast-2": "763104351884",
1388+
"ca-central-1": "763104351884",
1389+
"cn-north-1": "727897471807",
1390+
"cn-northwest-1": "727897471807",
1391+
"eu-central-1": "763104351884",
1392+
"eu-north-1": "763104351884",
1393+
"eu-south-1": "692866216735",
1394+
"eu-west-1": "763104351884",
1395+
"eu-west-2": "763104351884",
1396+
"eu-west-3": "763104351884",
1397+
"me-south-1": "217643126080",
1398+
"sa-east-1": "763104351884",
1399+
"us-east-1": "763104351884",
1400+
"us-east-2": "763104351884",
1401+
"us-gov-west-1": "442386744353",
1402+
"us-iso-east-1": "886529160074",
1403+
"us-west-1": "763104351884",
1404+
"us-west-2": "763104351884"
1405+
},
1406+
"repository": "tensorflow-inference"
13761407
}
13771408
}
13781409
},
@@ -1400,7 +1431,8 @@
14001431
"2.3": "2.3.2",
14011432
"2.4": "2.4.3",
14021433
"2.5": "2.5.1",
1403-
"2.6": "2.6.3"
1434+
"2.6": "2.6.3",
1435+
"2.8": "2.8.0"
14041436
},
14051437
"versions": {
14061438
"1.10.0": {
@@ -2692,6 +2724,39 @@
26922724
"us-west-2": "763104351884"
26932725
},
26942726
"repository": "tensorflow-training"
2727+
},
2728+
"2.8.0": {
2729+
"py_versions": [
2730+
"py39"
2731+
],
2732+
"registries": {
2733+
"af-south-1": "626614931356",
2734+
"ap-east-1": "871362719292",
2735+
"ap-northeast-1": "763104351884",
2736+
"ap-northeast-2": "763104351884",
2737+
"ap-northeast-3": "364406365360",
2738+
"ap-south-1": "763104351884",
2739+
"ap-southeast-1": "763104351884",
2740+
"ap-southeast-2": "763104351884",
2741+
"ca-central-1": "763104351884",
2742+
"cn-north-1": "727897471807",
2743+
"cn-northwest-1": "727897471807",
2744+
"eu-central-1": "763104351884",
2745+
"eu-north-1": "763104351884",
2746+
"eu-south-1": "692866216735",
2747+
"eu-west-1": "763104351884",
2748+
"eu-west-2": "763104351884",
2749+
"eu-west-3": "763104351884",
2750+
"me-south-1": "217643126080",
2751+
"sa-east-1": "763104351884",
2752+
"us-east-1": "763104351884",
2753+
"us-east-2": "763104351884",
2754+
"us-gov-west-1": "442386744353",
2755+
"us-iso-east-1": "886529160074",
2756+
"us-west-1": "763104351884",
2757+
"us-west-2": "763104351884"
2758+
},
2759+
"repository": "tensorflow-training"
26952760
}
26962761
}
26972762
}

tests/conftest.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ def _tf_py_version(tf_version, request):
352352
return request.param
353353
if Version("2.2") <= version < Version("2.6"):
354354
return "py37"
355-
return "py38"
355+
if Version("2.6") <= version < Version("2.8"):
356+
return "py38"
357+
return "py39"
356358

357359

358360
@pytest.fixture(scope="module")
@@ -384,7 +386,9 @@ def tf_full_py_version(tf_full_version):
384386
return "py3"
385387
if version < Version("2.6"):
386388
return "py37"
387-
return "py38"
389+
if version < Version("2.8"):
390+
return "py38"
391+
return "py39"
388392

389393

390394
@pytest.fixture(scope="session")

tests/unit/test_fw_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,7 @@ def test_validate_smdataparallel_args_not_raises():
678678
("ml.p3.16xlarge", "tensorflow", "2.3.2", "py37", smdataparallel_enabled),
679679
("ml.p3.16xlarge", "tensorflow", "2.3", "py37", smdataparallel_enabled),
680680
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled),
681+
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py37", smdataparallel_enabled),
681682
("ml.p3.16xlarge", "tensorflow", "2.4", "py37", smdataparallel_enabled),
682683
("ml.p3.16xlarge", "tensorflow", "2.5.0", "py37", smdataparallel_enabled),
683684
("ml.p3.16xlarge", "tensorflow", "2.5.1", "py37", smdataparallel_enabled),
@@ -686,6 +687,8 @@ def test_validate_smdataparallel_args_not_raises():
686687
("ml.p3.16xlarge", "tensorflow", "2.6.2", "py38", smdataparallel_enabled),
687688
("ml.p3.16xlarge", "tensorflow", "2.6.3", "py38", smdataparallel_enabled),
688689
("ml.p3.16xlarge", "tensorflow", "2.6", "py38", smdataparallel_enabled),
690+
("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled),
691+
("ml.p3.16xlarge", "tensorflow", "2.8", "py39", smdataparallel_enabled),
689692
("ml.p3.16xlarge", "pytorch", "1.6.0", "py3", smdataparallel_enabled),
690693
("ml.p3.16xlarge", "pytorch", "1.6", "py3", smdataparallel_enabled),
691694
("ml.p3.16xlarge", "pytorch", "1.7.1", "py3", smdataparallel_enabled),
@@ -698,10 +701,13 @@ def test_validate_smdataparallel_args_not_raises():
698701
("ml.p3.16xlarge", "pytorch", "1.10", "py38", smdataparallel_enabled),
699702
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
700703
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
704+
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
705+
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py37", smdataparallel_enabled_custom_mpi),
701706
("ml.p3.16xlarge", "tensorflow", "2.5.1", "py37", smdataparallel_enabled_custom_mpi),
702707
("ml.p3.16xlarge", "tensorflow", "2.6.0", "py38", smdataparallel_enabled_custom_mpi),
703708
("ml.p3.16xlarge", "tensorflow", "2.6.2", "py38", smdataparallel_enabled_custom_mpi),
704709
("ml.p3.16xlarge", "tensorflow", "2.6.3", "py38", smdataparallel_enabled_custom_mpi),
710+
("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled_custom_mpi),
705711
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
706712
("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi),
707713
]

0 commit comments

Comments
 (0)