|
54 | 54 | VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"]
|
55 | 55 | VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
|
56 | 56 |
|
| 57 | +MERGED_FRAMEWORKS_REPO_MAP = { |
| 58 | + "tensorflow-scriptmode": "tensorflow-training", |
| 59 | + "mxnet": "mxnet-training", |
| 60 | + "tensorflow-serving": "tensorflow-inference", |
| 61 | + "mxnet-serving": "mxnet-inference", |
| 62 | +} |
| 63 | + |
| 64 | +MERGED_FRAMEWORKS_LOWEST_VERSIONS = { |
| 65 | + "tensorflow-scriptmode": [1, 13, 1], |
| 66 | + "mxnet": [1, 4, 1], |
| 67 | + "tensorflow-serving": [1, 13, 0], |
| 68 | + "mxnet-serving": [1, 4, 1], |
| 69 | +} |
| 70 | + |
| 71 | + |
| 72 | +def is_version_equal_or_higher(lowest_version, framework_version): |
| 73 | + """Determine whether the ``framework_version`` is equal to or higher than ``lowest_version`` |
| 74 | +
|
| 75 | + Args: |
| 76 | + lowest_version (List[int]): lowest version represented in an integer list |
| 77 | + framework_version (str): framework version string |
| 78 | +
|
| 79 | + Returns: |
| 80 | + bool: Whether or not framework_version is equal to or higher than lowest_version |
| 81 | + """ |
| 82 | + version_list = [int(s) for s in framework_version.split(".")] |
| 83 | + return version_list >= lowest_version[0 : len(version_list)] |
| 84 | + |
| 85 | + |
| 86 | +def _is_merged_versions(framework, framework_version): |
| 87 | + lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS.get(framework) |
| 88 | + if lowest_version_list: |
| 89 | + return is_version_equal_or_higher(lowest_version_list, framework_version) |
| 90 | + else: |
| 91 | + return False |
| 92 | + |
| 93 | + |
| 94 | +def _using_merged_images(region, framework, py_version, accelerator_type, framework_version): |
| 95 | + is_gov_region = region in VALID_ACCOUNTS_BY_REGION |
| 96 | + is_py3 = py_version == "py3" or py_version is None |
| 97 | + is_merged_versions = _is_merged_versions(framework, framework_version) |
| 98 | + return (not is_gov_region) and is_merged_versions and is_py3 and accelerator_type is None |
| 99 | + |
| 100 | + |
| 101 | +def _registry_id(region, framework, py_version, account, accelerator_type, framework_version): |
| 102 | + if _using_merged_images(region, framework, py_version, accelerator_type, framework_version): |
| 103 | + return "763104351884" |
| 104 | + else: |
| 105 | + return VALID_ACCOUNTS_BY_REGION.get(region, account) |
| 106 | + |
57 | 107 |
|
58 | 108 | def create_image_uri(
|
59 | 109 | region,
|
@@ -86,8 +136,15 @@ def create_image_uri(
|
86 | 136 | if py_version and py_version not in VALID_PY_VERSIONS:
|
87 | 137 | raise ValueError("invalid py_version argument: {}".format(py_version))
|
88 | 138 |
|
89 |
| - # Handle Account Number for Gov Cloud |
90 |
| - account = VALID_ACCOUNTS_BY_REGION.get(region, account) |
| 139 | + # Handle Account Number for Gov Cloud and frameworks with DLC merged images |
| 140 | + account = _registry_id( |
| 141 | + region=region, |
| 142 | + framework=framework, |
| 143 | + py_version=py_version, |
| 144 | + account=account, |
| 145 | + accelerator_type=accelerator_type, |
| 146 | + framework_version=framework_version, |
| 147 | + ) |
91 | 148 |
|
92 | 149 | # Handle Local Mode
|
93 | 150 | if instance_type.startswith("local"):
|
@@ -121,7 +178,14 @@ def create_image_uri(
|
121 | 178 | ):
|
122 | 179 | framework += "-eia"
|
123 | 180 |
|
124 |
| - return "{}/sagemaker-{}:{}".format(get_ecr_image_uri_prefix(account, region), framework, tag) |
| 181 | + if _using_merged_images(region, framework, py_version, accelerator_type, framework_version): |
| 182 | + return "{}/{}:{}".format( |
| 183 | + get_ecr_image_uri_prefix(account, region), MERGED_FRAMEWORKS_REPO_MAP[framework], tag |
| 184 | + ) |
| 185 | + else: |
| 186 | + return "{}/sagemaker-{}:{}".format( |
| 187 | + get_ecr_image_uri_prefix(account, region), framework, tag |
| 188 | + ) |
125 | 189 |
|
126 | 190 |
|
127 | 191 | def _accelerator_type_valid_for_framework(
|
@@ -264,7 +328,7 @@ def framework_name_from_image(image_name):
|
264 | 328 | # extract framework, python version and image tag
|
265 | 329 | # We must support both the legacy and current image name format.
|
266 | 330 | name_pattern = re.compile(
|
267 |
| - r"^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 |
| 331 | + r"^(?:sagemaker(?:-rl)?-)?(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode|training)?:(.*)-(.*?)-(py2|py3)$" # noqa: E501 |
268 | 332 | )
|
269 | 333 | legacy_name_pattern = re.compile(r"^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$")
|
270 | 334 |
|
|
0 commit comments