Skip to content

Commit ca39a2b

Browse files
authored
fix: fix ECR URI validation (#719)
1 parent cd2b23d commit ca39a2b

File tree

5 files changed

+21
-7
lines changed

5 files changed

+21
-7
lines changed

src/sagemaker/fw_utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import tempfile
2222
from six.moves.urllib.parse import urlparse
2323

24-
from sagemaker.utils import get_ecr_image_uri_prefix
24+
from sagemaker.utils import get_ecr_image_uri_prefix, ECR_URI_PATTERN
2525

2626
_TAR_SOURCE_FILENAME = 'source.tar.gz'
2727

@@ -223,7 +223,7 @@ def framework_name_from_image(image_name):
223223
str: The image tag
224224
str: If the image is script mode
225225
"""
226-
sagemaker_pattern = re.compile(r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)amazonaws.com(/)(.*:.*)$')
226+
sagemaker_pattern = re.compile(ECR_URI_PATTERN)
227227
sagemaker_match = sagemaker_pattern.match(image_name)
228228
if sagemaker_match is None:
229229
return None, None, None, None
@@ -235,8 +235,8 @@ def framework_name_from_image(image_name):
235235
legacy_name_pattern = re.compile(
236236
r'^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
237237

238-
name_match = name_pattern.match(sagemaker_match.group(8))
239-
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))
238+
name_match = name_pattern.match(sagemaker_match.group(9))
239+
legacy_match = legacy_name_pattern.match(sagemaker_match.group(9))
240240

241241
if name_match is not None:
242242
fw, scriptmode, ver, device, py = name_match.group(1), name_match.group(2), name_match.group(3),\

src/sagemaker/local/image.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import os
2020
import platform
2121
import random
22+
import re
2223
import shlex
2324
import shutil
2425
import string
@@ -688,7 +689,9 @@ def _write_json_file(filename, content):
688689

689690
def _ecr_login_if_needed(boto_session, image):
690691
# Only ECR images need login
691-
if not ('dkr.ecr' in image and 'amazonaws.com' in image):
692+
sagemaker_pattern = re.compile(sagemaker.utils.ECR_URI_PATTERN)
693+
sagemaker_match = sagemaker_pattern.match(image)
694+
if not sagemaker_match:
692695
return False
693696

694697
# do we have the image?

src/sagemaker/utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
import six
2828

2929

30+
ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws.com|c2s.ic.gov)(/)(.*:.*)$'
31+
32+
3033
# Use the base name of the image as the job name if the user doesn't give us one
3134
def name_from_image(image):
3235
"""Create a training job name based on the image name and a timestamp.

tests/unit/test_fw_utils.py

+5
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,11 @@ def test_framework_name_from_image_mxnet():
380380
assert ('mxnet', 'py3', '1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name)
381381

382382

383+
def test_framework_name_from_image_mxnet_in_gov():
384+
image_name = '123.dkr.ecr.region-name.c2s.ic.gov/sagemaker-mxnet:1.1-gpu-py3'
385+
assert ('mxnet', 'py3', '1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name)
386+
387+
383388
def test_framework_name_from_image_tf():
384389
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.6-cpu-py2'
385390
assert ('tensorflow', 'py2', '1.6-cpu-py2', None) == fw_utils.framework_name_from_image(image_name)

tests/unit/test_image.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -622,10 +622,13 @@ def test_ecr_login_non_ecr():
622622

623623

624624
@patch('sagemaker.local.image._check_output', return_value='123451324')
625-
def test_ecr_login_image_exists(_check_output):
625+
@pytest.mark.parametrize('image', [
626+
'520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0',
627+
'520713654638.dkr.ecr.us-iso-east-1.c2s.ic.gov/image-i-have:1.0'
628+
])
629+
def test_ecr_login_image_exists(_check_output, image):
626630
session_mock = Mock()
627631

628-
image = '520713654638.dkr.ecr.us-east-1.amazonaws.com/image-i-have:1.0'
629632
result = sagemaker.local.image._ecr_login_if_needed(session_mock, image)
630633

631634
session_mock.assert_not_called()

0 commit comments

Comments
 (0)