Skip to content

Commit b8d5470

Browse files
authored
change: use DLC images for GovCloud (#1329)
1 parent 5c4d603 commit b8d5470

File tree

2 files changed

+61
-25
lines changed

2 files changed

+61
-25
lines changed

src/sagemaker/fw_utils.py

+3-22
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
]
6363
PY2_RESTRICTED_EIA_FRAMEWORKS = ["pytorch-serving"]
6464
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
65-
ASIMOV_VALID_ACCOUNTS_BY_REGION = {"us-iso-east-1": "886529160074"}
65+
ASIMOV_VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074"}
6666
OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634", "me-south-1": "724002660598"}
6767
ASIMOV_OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "871362719292", "me-south-1": "217643126080"}
6868
DEFAULT_ACCOUNT = "520713654638"
@@ -133,25 +133,6 @@ def _is_dlc_version(framework, framework_version, py_version):
133133
return False
134134

135135

136-
def _use_dlc_image(region, framework, py_version, framework_version):
137-
"""Return if the DLC image should be used for the given framework,
138-
framework version, Python version, and region.
139-
140-
Args:
141-
region (str): The AWS region.
142-
framework (str): The framework name, e.g. "tensorflow-scriptmode".
143-
py_version (str): The Python version, e.g. "py3".
144-
framework_version (str): The framework version.
145-
146-
Returns:
147-
bool: Whether or not to use the corresponding DLC image.
148-
"""
149-
is_gov_region = region in VALID_ACCOUNTS_BY_REGION
150-
is_dlc_version = _is_dlc_version(framework, framework_version, py_version)
151-
152-
return ((not is_gov_region) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION) and is_dlc_version
153-
154-
155136
def _registry_id(region, framework, py_version, account, framework_version):
156137
"""Return the Amazon ECR registry number (or AWS account ID) for
157138
the given framework, framework version, Python version, and region.
@@ -168,7 +149,7 @@ def _registry_id(region, framework, py_version, account, framework_version):
168149
specific one for the framework, framework version, Python version,
169150
and region, then ``account`` is returned.
170151
"""
171-
if _use_dlc_image(region, framework, py_version, framework_version):
152+
if _is_dlc_version(framework, framework_version, py_version):
172153
if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION:
173154
return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION.get(region)
174155
if region in ASIMOV_VALID_ACCOUNTS_BY_REGION:
@@ -253,7 +234,7 @@ def create_image_uri(
253234
else:
254235
device_type = "cpu"
255236

256-
use_dlc_image = _use_dlc_image(region, framework, py_version, framework_version)
237+
use_dlc_image = _is_dlc_version(framework, framework_version, py_version)
257238

258239
if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia"):
259240
tag = "{}-{}".format(framework_version, device_type)

tests/unit/test_fw_utils.py

+58-3
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ def test_create_image_uri_hkg_override_account():
363363
assert {image_uri == "fake.dkr.ecr.ap-east-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"}
364364

365365

366-
def test_create_image_uri_merged():
366+
def test_create_dlc_image_uri():
367367
image_uri = fw_utils.create_image_uri(
368368
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py3"
369369
)
@@ -419,7 +419,7 @@ def test_create_image_uri_merged():
419419
)
420420

421421

422-
def test_create_image_uri_merged_py2():
422+
def test_create_dlc_image_uri_py2():
423423
image_uri = fw_utils.create_image_uri(
424424
"us-west-2", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.13.1", "py2"
425425
)
@@ -450,7 +450,7 @@ def test_create_image_uri_merged_py2():
450450
)
451451

452452

453-
def test_create_image_uri_merged_gov_regions():
453+
def test_create_dlc_image_uri_iso_east_1():
454454
image_uri = fw_utils.create_image_uri(
455455
"us-iso-east-1", "tensorflow-scriptmode", "ml.m4.xlarge", "1.13.1", "py3"
456456
)
@@ -493,6 +493,61 @@ def test_create_image_uri_merged_gov_regions():
493493
)
494494

495495

496+
def test_create_dlc_image_uri_gov_west_1():
497+
image_uri = fw_utils.create_image_uri(
498+
"us-gov-west-1", "tensorflow-scriptmode", "ml.m4.xlarge", "1.13.1", "py3"
499+
)
500+
assert (
501+
image_uri
502+
== "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/tensorflow-training:1.13.1-cpu-py3"
503+
)
504+
505+
image_uri = fw_utils.create_image_uri(
506+
"us-gov-west-1", "tensorflow-scriptmode", "ml.p3.2xlarge", "1.14", "py2"
507+
)
508+
assert (
509+
image_uri
510+
== "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/tensorflow-training:1.14-gpu-py2"
511+
)
512+
513+
image_uri = fw_utils.create_image_uri(
514+
"us-gov-west-1", "tensorflow-serving", "ml.m4.xlarge", "1.13.0"
515+
)
516+
assert (
517+
image_uri
518+
== "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/tensorflow-inference:1.13.0-cpu"
519+
)
520+
521+
image_uri = fw_utils.create_image_uri("us-gov-west-1", "mxnet", "ml.p3.2xlarge", "1.4.1", "py3")
522+
assert (
523+
image_uri == "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/mxnet-training:1.4.1-gpu-py3"
524+
)
525+
526+
image_uri = fw_utils.create_image_uri(
527+
"us-gov-west-1", "mxnet-serving", "ml.c4.2xlarge", "1.4.1", "py3"
528+
)
529+
assert (
530+
image_uri
531+
== "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/mxnet-inference:1.4.1-cpu-py3"
532+
)
533+
534+
image_uri = fw_utils.create_image_uri(
535+
"us-gov-west-1", "pytorch", "ml.p3.2xlarge", "1.2.0", "py3"
536+
)
537+
assert (
538+
image_uri
539+
== "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-training:1.2.0-gpu-py3"
540+
)
541+
542+
image_uri = fw_utils.create_image_uri(
543+
"us-gov-west-1", "pytorch-serving", "ml.c4.2xlarge", "1.2.0", "py3"
544+
)
545+
assert (
546+
image_uri
547+
== "442386744353.dkr.ecr.us-gov-west-1.amazonaws.com/pytorch-inference:1.2.0-cpu-py3"
548+
)
549+
550+
496551
def test_create_image_uri_pytorch(pytorch_version):
497552
image_uri = fw_utils.create_image_uri(
498553
"us-west-2", "pytorch", "ml.p3.2xlarge", pytorch_version, "py3"

0 commit comments

Comments
 (0)