From 4fd9d6ee0d8da5ec29610b2471ce7cdba1af843a Mon Sep 17 00:00:00 2001 From: Yu Date: Fri, 22 Mar 2019 16:12:42 -0700 Subject: [PATCH 1/2] fix: fix ECR URI validation for gov cloud --- src/sagemaker/fw_utils.py | 6 +++--- src/sagemaker/local/image.py | 2 +- tests/unit/test_fw_utils.py | 5 +++++ tests/unit/test_image.py | 7 +++++-- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 629b2386f4..3e1f75ab7d 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -223,7 +223,7 @@ def framework_name_from_image(image_name): str: The image tag str: If the image is script mode """ - sagemaker_pattern = re.compile(r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)amazonaws.com(/)(.*:.*)$') + sagemaker_pattern = re.compile(r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$') sagemaker_match = sagemaker_pattern.match(image_name) if sagemaker_match is None: return None, None, None, None @@ -235,8 +235,8 @@ def framework_name_from_image(image_name): legacy_name_pattern = re.compile( r'^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$') - name_match = name_pattern.match(sagemaker_match.group(8)) - legacy_match = legacy_name_pattern.match(sagemaker_match.group(8)) + name_match = name_pattern.match(sagemaker_match.group(9)) + legacy_match = legacy_name_pattern.match(sagemaker_match.group(9)) if name_match is not None: fw, scriptmode, ver, device, py = name_match.group(1), name_match.group(2), name_match.group(3),\ diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index b0e6800552..7bbe8b623a 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -688,7 +688,7 @@ def _write_json_file(filename, content): def _ecr_login_if_needed(boto_session, image): # Only ECR images need login - if not ('dkr.ecr' in image and 'amazonaws.com' in image): + if not ('dkr.ecr' in image and ('amazonaws.com' in image or 'c2s.ic.gov' in image)): return False # do we have the image? diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 61c2f28869..579b105f9c 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -380,6 +380,11 @@ def test_framework_name_from_image_mxnet(): assert ('mxnet', 'py3', '1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name) +def test_framework_name_from_image_mxnet_in_gov(): + image_name = '123.dkr.ecr.region-name.c2s.ic.gov/sagemaker-mxnet:1.1-gpu-py3' + assert ('mxnet', 'py3', '1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name) + + def test_framework_name_from_image_tf(): image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.6-cpu-py2' assert ('tensorflow', 'py2', '1.6-cpu-py2', None) == fw_utils.framework_name_from_image(image_name) diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index 4863f06318..d5bfdbcbcf 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -622,10 +622,13 @@ def test_ecr_login_non_ecr(): @patch('sagemaker.local.image._check_output', return_value='123451324') -def test_ecr_login_image_exists(_check_output): +@pytest.mark.parametrize('image', [ + '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0', + '520713654638.dkr.ecr.us-east-1.c2s.ic.gov.com/image-i-have:1.0' +]) +def test_ecr_login_image_exists(_check_output, image): session_mock = Mock() - image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0' result = sagemaker.local.image._ecr_login_if_needed(session_mock, image) session_mock.assert_not_called() From 561723ddef4ab994a8a48a01c6c9c7ff365a5f3e Mon Sep 17 00:00:00 2001 From: Yu Date: Fri, 22 Mar 2019 16:41:40 -0700 Subject: [PATCH 2/2] Address comments --- src/sagemaker/fw_utils.py | 4 ++-- src/sagemaker/local/image.py | 5 ++++- src/sagemaker/utils.py | 3 +++ tests/unit/test_image.py | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 3e1f75ab7d..2cb530a209 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -21,7 +21,7 @@ import tempfile from six.moves.urllib.parse import urlparse -from sagemaker.utils import get_ecr_image_uri_prefix +from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN _TAR_SOURCE_FILENAME = 'source.tar.gz' @@ -223,7 +223,7 @@ def framework_name_from_image(image_name): str: The image tag str: If the image is script mode """ - sagemaker_pattern = re.compile(r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$') + sagemaker_pattern = re.compile(ECR_URI_PATTERN) sagemaker_match = sagemaker_pattern.match(image_name) if sagemaker_match is None: return None, None, None, None diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index 7bbe8b623a..203a73eebe 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -19,6 +19,7 @@ import os import platform import random +import re import shlex import shutil import string @@ -688,7 +689,9 @@ def _write_json_file(filename, content): def _ecr_login_if_needed(boto_session, image): # Only ECR images need login - if not ('dkr.ecr' in image and ('amazonaws.com' in image or 'c2s.ic.gov' in image)): + sagemaker_pattern = re.compile(sagemaker.utils.ECR_URI_PATTERN) + sagemaker_match = sagemaker_pattern.match(image) + if not sagemaker_match: return False # do we have the image? diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index f2e6d86b94..417619e38b 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -27,6 +27,9 @@ import six +ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$' + + # Use the base name of the image as the job name if the user doesn't give us one def name_from_image(image): """Create a training job name based on the image name and a timestamp. diff --git a/tests/unit/test_image.py b/tests/unit/test_image.py index d5bfdbcbcf..bec6610686 100644 --- a/tests/unit/test_image.py +++ b/tests/unit/test_image.py @@ -624,7 +624,7 @@ def test_ecr_login_non_ecr(): @patch('sagemaker.local.image._check_output', return_value='123451324') @pytest.mark.parametrize('image', [ '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0', - '520713654638.dkr.ecr.us-east-1.c2s.ic.gov.com/image-i-have:1.0' + '520713654638.dkr.ecr.us-iso-east-1.c2s.ic.gov/image-i-have:1.0' ]) def test_ecr_login_image_exists(_check_output, image): session_mock = Mock()