Skip to content

Commit 42d5908

Browse files
authored
change: add account number and unit tests for govcloud (#713)
1 parent 2642a15 commit 42d5908

File tree

7 files changed

+115
-71
lines changed

7 files changed

+115
-71
lines changed

src/sagemaker/amazon/amazon_estimator.py

+71-66
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sagemaker.amazon.common import write_numpy_to_dense_tensor
2424
from sagemaker.estimator import EstimatorBase, _TrainingJob
2525
from sagemaker.session import s3_input
26-
from sagemaker.utils import sagemaker_timestamp
26+
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix
2727

2828
logger = logging.getLogger(__name__)
2929

@@ -283,86 +283,91 @@ def registry(region_name, algorithm=None):
283283
284284
https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/amazon
285285
"""
286-
if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines", "ntm",
287-
"randomcutforest", "knn", "object2vec", "ipinsights"]:
286+
if algorithm in [None, 'pca', 'kmeans', 'linear-learner', 'factorization-machines', 'ntm',
287+
'randomcutforest', 'knn', 'object2vec', 'ipinsights']:
288288
account_id = {
289-
"us-east-1": "382416733822",
290-
"us-east-2": "404615174143",
291-
"us-west-2": "174872318107",
292-
"eu-west-1": "438346466558",
293-
"eu-central-1": "664544806723",
294-
"ap-northeast-1": "351501993468",
295-
"ap-northeast-2": "835164637446",
296-
"ap-southeast-2": "712309505854",
297-
"us-gov-west-1": "226302683700",
298-
"ap-southeast-1": "475088953585",
299-
"ap-south-1": "991648021394",
300-
"ca-central-1": "469771592824",
301-
"eu-west-2": "644912444149",
302-
"us-west-1": "632365934929",
289+
'us-east-1': '382416733822',
290+
'us-east-2': '404615174143',
291+
'us-west-2': '174872318107',
292+
'eu-west-1': '438346466558',
293+
'eu-central-1': '664544806723',
294+
'ap-northeast-1': '351501993468',
295+
'ap-northeast-2': '835164637446',
296+
'ap-southeast-2': '712309505854',
297+
'us-gov-west-1': '226302683700',
298+
'ap-southeast-1': '475088953585',
299+
'ap-south-1': '991648021394',
300+
'ca-central-1': '469771592824',
301+
'eu-west-2': '644912444149',
302+
'us-west-1': '632365934929',
303+
'us-iso-east-1': '490574956308',
303304
}[region_name]
304-
elif algorithm in ["lda"]:
305+
elif algorithm in ['lda']:
305306
account_id = {
306-
"us-east-1": "766337827248",
307-
"us-east-2": "999911452149",
308-
"us-west-2": "266724342769",
309-
"eu-west-1": "999678624901",
310-
"eu-central-1": "353608530281",
311-
"ap-northeast-1": "258307448986",
312-
"ap-northeast-2": "293181348795",
313-
"ap-southeast-2": "297031611018",
314-
"us-gov-west-1": "226302683700",
315-
"ap-southeast-1": "475088953585",
316-
"ap-south-1": "991648021394",
317-
"ca-central-1": "469771592824",
318-
"eu-west-2": "644912444149",
319-
"us-west-1": "632365934929",
307+
'us-east-1': '766337827248',
308+
'us-east-2': '999911452149',
309+
'us-west-2': '266724342769',
310+
'eu-west-1': '999678624901',
311+
'eu-central-1': '353608530281',
312+
'ap-northeast-1': '258307448986',
313+
'ap-northeast-2': '293181348795',
314+
'ap-southeast-2': '297031611018',
315+
'us-gov-west-1': '226302683700',
316+
'ap-southeast-1': '475088953585',
317+
'ap-south-1': '991648021394',
318+
'ca-central-1': '469771592824',
319+
'eu-west-2': '644912444149',
320+
'us-west-1': '632365934929',
321+
'us-iso-east-1': '490574956308',
320322
}[region_name]
321-
elif algorithm in ["forecasting-deepar"]:
323+
elif algorithm in ['forecasting-deepar']:
322324
account_id = {
323-
"us-east-1": "522234722520",
324-
"us-east-2": "566113047672",
325-
"us-west-2": "156387875391",
326-
"eu-west-1": "224300973850",
327-
"eu-central-1": "495149712605",
328-
"ap-northeast-1": "633353088612",
329-
"ap-northeast-2": "204372634319",
330-
"ap-southeast-2": "514117268639",
331-
"us-gov-west-1": "226302683700",
332-
"ap-southeast-1": "475088953585",
333-
"ap-south-1": "991648021394",
334-
"ca-central-1": "469771592824",
335-
"eu-west-2": "644912444149",
336-
"us-west-1": "632365934929",
325+
'us-east-1': '522234722520',
326+
'us-east-2': '566113047672',
327+
'us-west-2': '156387875391',
328+
'eu-west-1': '224300973850',
329+
'eu-central-1': '495149712605',
330+
'ap-northeast-1': '633353088612',
331+
'ap-northeast-2': '204372634319',
332+
'ap-southeast-2': '514117268639',
333+
'us-gov-west-1': '226302683700',
334+
'ap-southeast-1': '475088953585',
335+
'ap-south-1': '991648021394',
336+
'ca-central-1': '469771592824',
337+
'eu-west-2': '644912444149',
338+
'us-west-1': '632365934929',
339+
'us-iso-east-1': '490574956308',
337340
}[region_name]
338-
elif algorithm in ["xgboost", "seq2seq", "image-classification", "blazingtext",
339-
"object-detection", "semantic-segmentation"]:
341+
elif algorithm in ['xgboost', 'seq2seq', 'image-classification', 'blazingtext',
342+
'object-detection', 'semantic-segmentation']:
340343
account_id = {
341-
"us-east-1": "811284229777",
342-
"us-east-2": "825641698319",
343-
"us-west-2": "433757028032",
344-
"eu-west-1": "685385470294",
345-
"eu-central-1": "813361260812",
346-
"ap-northeast-1": "501404015308",
347-
"ap-northeast-2": "306986355934",
348-
"ap-southeast-2": "544295431143",
349-
"us-gov-west-1": "226302683700",
350-
"ap-southeast-1": "475088953585",
351-
"ap-south-1": "991648021394",
352-
"ca-central-1": "469771592824",
353-
"eu-west-2": "644912444149",
354-
"us-west-1": "632365934929",
344+
'us-east-1': '811284229777',
345+
'us-east-2': '825641698319',
346+
'us-west-2': '433757028032',
347+
'eu-west-1': '685385470294',
348+
'eu-central-1': '813361260812',
349+
'ap-northeast-1': '501404015308',
350+
'ap-northeast-2': '306986355934',
351+
'ap-southeast-2': '544295431143',
352+
'us-gov-west-1': '226302683700',
353+
'ap-southeast-1': '475088953585',
354+
'ap-south-1': '991648021394',
355+
'ca-central-1': '469771592824',
356+
'eu-west-2': '644912444149',
357+
'us-west-1': '632365934929',
358+
'us-iso-east-1': '490574956308',
355359
}[region_name]
356360
elif algorithm in ['image-classification-neo', 'xgboost-neo']:
357361
account_id = {
358362
'us-west-2': '301217895009',
359363
'us-east-1': '785573368785',
360364
'eu-west-1': '802834080501',
361-
'us-east-2': '007439368137'
365+
'us-east-2': '007439368137',
362366
}[region_name]
363367
else:
364-
raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm))
365-
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
368+
raise ValueError('Algorithm class:{} does not have mapping to account_id with images'.format(algorithm))
369+
370+
return get_ecr_image_uri_prefix(account_id, region_name)
366371

367372

368373
def get_image_uri(region_name, repo_name, repo_version=1):

src/sagemaker/fw_registry.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from __future__ import absolute_import
1414
import logging
1515

16+
from sagemaker.utils import get_ecr_image_uri_prefix
17+
1618
image_registry_map = {
1719
"us-west-1": {
1820
"sparkml-serving": "746614075791",
@@ -69,6 +71,10 @@
6971
"us-gov-west-1": {
7072
"sparkml-serving": "414596584902",
7173
"scikit-learn": "414596584902"
74+
},
75+
"us-iso-east-1": {
76+
"sparkml-serving": "833128469047",
77+
"scikit-learn": "833128469047"
7278
}
7379
}
7480

@@ -80,7 +86,7 @@ def registry(region_name, framework=None):
8086
"""
8187
try:
8288
account_id = image_registry_map[region_name][framework]
83-
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
89+
return get_ecr_image_uri_prefix(account_id, region_name)
8490
except KeyError:
8591
logging.error("The specific image or region does not exist")
8692
raise

src/sagemaker/fw_utils.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import tempfile
2222
from six.moves.urllib.parse import urlparse
2323

24+
from sagemaker.utils import get_ecr_image_uri_prefix
25+
2426
_TAR_SOURCE_FILENAME = 'source.tar.gz'
2527

2628
UploadedCode = namedtuple('UserCode', ['s3_prefix', 'script_name'])
@@ -40,7 +42,8 @@
4042

4143
VALID_PY_VERSIONS = ['py2', 'py3']
4244
VALID_EIA_FRAMEWORKS = ['tensorflow', 'mxnet']
43-
VALID_ACCOUNTS_BY_REGION = {'us-gov-west-1': '246785580436'}
45+
VALID_ACCOUNTS_BY_REGION = {'us-gov-west-1': '246785580436',
46+
'us-iso-east-1': '744548109606'}
4447

4548

4649
def create_image_uri(region, framework, instance_type, framework_version, py_version=None,
@@ -96,8 +99,8 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
9699
optimized_families=optimized_families):
97100
framework += '-eia'
98101

99-
return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \
100-
.format(account, region, framework, tag)
102+
return "{}/sagemaker-{}:{}" \
103+
.format(get_ecr_image_uri_prefix(account, region), framework, tag)
101104

102105

103106
def _accelerator_type_valid_for_framework(framework, accelerator_type=None, optimized_families=None):

src/sagemaker/utils.py

+14
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,20 @@ def download_file(bucket_name, path, target, sagemaker_session):
292292
bucket.download_file(path, target)
293293

294294

295+
def get_ecr_image_uri_prefix(account, region):
296+
"""get prefix of ECR image URI
297+
298+
Args:
299+
account (str): AWS account number
300+
region (str): AWS region name
301+
302+
Returns:
303+
(str): URI prefix of ECR image
304+
"""
305+
domain = 'c2s.ic.gov' if region == 'us-iso-east-1' else 'amazonaws.com'
306+
return '{}.dkr.ecr.{}.{}'.format(account, region, domain)
307+
308+
295309
class DeferredError(object):
296310
"""Stores an exception and raises it at a later time if this
297311
object is accessed in any way. Useful to allow soft-dependencies on imports,

tests/unit/test_amazon_estimator.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# Use PCA as a test implementation of AmazonAlgorithmEstimator
2020
from sagemaker.amazon.pca import PCA
21-
from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry
21+
from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry, get_image_uri
2222

2323
COMMON_ARGS = {'role': 'myrole', 'train_instance_count': 1, 'train_instance_type': 'ml.c4.xlarge'}
2424

@@ -61,6 +61,14 @@ def sagemaker_session():
6161
return sms
6262

6363

64+
def test_gov_ecr_uri():
65+
assert get_image_uri('us-gov-west-1', 'kmeans', 'latest') == \
66+
'226302683700.dkr.ecr.us-gov-west-1.amazonaws.com/kmeans:latest'
67+
68+
assert get_image_uri('us-iso-east-1', 'kmeans', 'latest') == \
69+
'490574956308.dkr.ecr.us-iso-east-1.c2s.ic.gov/kmeans:latest'
70+
71+
6472
def test_init(sagemaker_session):
6573
pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)
6674
assert pca.num_components == 55

tests/unit/test_fw_registry.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_registry_sparkml_serving():
3636
assert registry('eu-central-1', 'sparkml-serving') == "492215442770.dkr.ecr.eu-central-1.amazonaws.com"
3737
assert registry('ca-central-1', 'sparkml-serving') == "341280168497.dkr.ecr.ca-central-1.amazonaws.com"
3838
assert registry('us-gov-west-1', 'sparkml-serving') == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com"
39+
assert registry('us-iso-east-1', 'sparkml-serving') == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov"
3940

4041

4142
def test_registry_sklearn():
@@ -57,6 +58,7 @@ def test_registry_sklearn():
5758
assert registry('eu-central-1', scikit_learn_framework_name) == "492215442770.dkr.ecr.eu-central-1.amazonaws.com"
5859
assert registry('ca-central-1', scikit_learn_framework_name) == "341280168497.dkr.ecr.ca-central-1.amazonaws.com"
5960
assert registry('us-gov-west-1', scikit_learn_framework_name) == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com"
61+
assert registry('us-iso-east-1', scikit_learn_framework_name) == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov"
6062

6163

6264
def test_default_sklearn_image_uri():

tests/unit/test_fw_utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@ def test_create_image_uri_cpu():
6161
image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'local', '1.0rc', 'py2', '23')
6262
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'
6363

64+
image_uri = fw_utils.create_image_uri('us-gov-west-1', MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23')
65+
assert image_uri == '246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'
66+
67+
image_uri = fw_utils.create_image_uri('us-iso-east-1', MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23')
68+
assert image_uri == '744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2'
69+
6470

6571
def test_create_image_uri_no_python():
6672
image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', account='23')

0 commit comments

Comments
 (0)