Skip to content

change: add account number and unit tests for govcloud #713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 71 additions & 66 deletions src/sagemaker/amazon/amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sagemaker.amazon.common import write_numpy_to_dense_tensor
from sagemaker.estimator import EstimatorBase, _TrainingJob
from sagemaker.session import s3_input
from sagemaker.utils import sagemaker_timestamp
from sagemaker.utils import sagemaker_timestamp, get_ecr_image_uri_prefix

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -283,86 +283,91 @@ def registry(region_name, algorithm=None):

https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/amazon
"""
if algorithm in [None, "pca", "kmeans", "linear-learner", "factorization-machines", "ntm",
"randomcutforest", "knn", "object2vec", "ipinsights"]:
if algorithm in [None, 'pca', 'kmeans', 'linear-learner', 'factorization-machines', 'ntm',
'randomcutforest', 'knn', 'object2vec', 'ipinsights']:
account_id = {
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-west-2": "174872318107",
"eu-west-1": "438346466558",
"eu-central-1": "664544806723",
"ap-northeast-1": "351501993468",
"ap-northeast-2": "835164637446",
"ap-southeast-2": "712309505854",
"us-gov-west-1": "226302683700",
"ap-southeast-1": "475088953585",
"ap-south-1": "991648021394",
"ca-central-1": "469771592824",
"eu-west-2": "644912444149",
"us-west-1": "632365934929",
'us-east-1': '382416733822',
'us-east-2': '404615174143',
'us-west-2': '174872318107',
'eu-west-1': '438346466558',
'eu-central-1': '664544806723',
'ap-northeast-1': '351501993468',
'ap-northeast-2': '835164637446',
'ap-southeast-2': '712309505854',
'us-gov-west-1': '226302683700',
'ap-southeast-1': '475088953585',
'ap-south-1': '991648021394',
'ca-central-1': '469771592824',
'eu-west-2': '644912444149',
'us-west-1': '632365934929',
'us-iso-east-1': '490574956308',
}[region_name]
elif algorithm in ["lda"]:
elif algorithm in ['lda']:
account_id = {
"us-east-1": "766337827248",
"us-east-2": "999911452149",
"us-west-2": "266724342769",
"eu-west-1": "999678624901",
"eu-central-1": "353608530281",
"ap-northeast-1": "258307448986",
"ap-northeast-2": "293181348795",
"ap-southeast-2": "297031611018",
"us-gov-west-1": "226302683700",
"ap-southeast-1": "475088953585",
"ap-south-1": "991648021394",
"ca-central-1": "469771592824",
"eu-west-2": "644912444149",
"us-west-1": "632365934929",
'us-east-1': '766337827248',
'us-east-2': '999911452149',
'us-west-2': '266724342769',
'eu-west-1': '999678624901',
'eu-central-1': '353608530281',
'ap-northeast-1': '258307448986',
'ap-northeast-2': '293181348795',
'ap-southeast-2': '297031611018',
'us-gov-west-1': '226302683700',
'ap-southeast-1': '475088953585',
'ap-south-1': '991648021394',
'ca-central-1': '469771592824',
'eu-west-2': '644912444149',
'us-west-1': '632365934929',
'us-iso-east-1': '490574956308',
}[region_name]
elif algorithm in ["forecasting-deepar"]:
elif algorithm in ['forecasting-deepar']:
account_id = {
"us-east-1": "522234722520",
"us-east-2": "566113047672",
"us-west-2": "156387875391",
"eu-west-1": "224300973850",
"eu-central-1": "495149712605",
"ap-northeast-1": "633353088612",
"ap-northeast-2": "204372634319",
"ap-southeast-2": "514117268639",
"us-gov-west-1": "226302683700",
"ap-southeast-1": "475088953585",
"ap-south-1": "991648021394",
"ca-central-1": "469771592824",
"eu-west-2": "644912444149",
"us-west-1": "632365934929",
'us-east-1': '522234722520',
'us-east-2': '566113047672',
'us-west-2': '156387875391',
'eu-west-1': '224300973850',
'eu-central-1': '495149712605',
'ap-northeast-1': '633353088612',
'ap-northeast-2': '204372634319',
'ap-southeast-2': '514117268639',
'us-gov-west-1': '226302683700',
'ap-southeast-1': '475088953585',
'ap-south-1': '991648021394',
'ca-central-1': '469771592824',
'eu-west-2': '644912444149',
'us-west-1': '632365934929',
'us-iso-east-1': '490574956308',
}[region_name]
elif algorithm in ["xgboost", "seq2seq", "image-classification", "blazingtext",
"object-detection", "semantic-segmentation"]:
elif algorithm in ['xgboost', 'seq2seq', 'image-classification', 'blazingtext',
'object-detection', 'semantic-segmentation']:
account_id = {
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-west-2": "433757028032",
"eu-west-1": "685385470294",
"eu-central-1": "813361260812",
"ap-northeast-1": "501404015308",
"ap-northeast-2": "306986355934",
"ap-southeast-2": "544295431143",
"us-gov-west-1": "226302683700",
"ap-southeast-1": "475088953585",
"ap-south-1": "991648021394",
"ca-central-1": "469771592824",
"eu-west-2": "644912444149",
"us-west-1": "632365934929",
'us-east-1': '811284229777',
'us-east-2': '825641698319',
'us-west-2': '433757028032',
'eu-west-1': '685385470294',
'eu-central-1': '813361260812',
'ap-northeast-1': '501404015308',
'ap-northeast-2': '306986355934',
'ap-southeast-2': '544295431143',
'us-gov-west-1': '226302683700',
'ap-southeast-1': '475088953585',
'ap-south-1': '991648021394',
'ca-central-1': '469771592824',
'eu-west-2': '644912444149',
'us-west-1': '632365934929',
'us-iso-east-1': '490574956308',
}[region_name]
elif algorithm in ['image-classification-neo', 'xgboost-neo']:
account_id = {
'us-west-2': '301217895009',
'us-east-1': '785573368785',
'eu-west-1': '802834080501',
'us-east-2': '007439368137'
'us-east-2': '007439368137',
}[region_name]
else:
raise ValueError("Algorithm class:{} doesn't have mapping to account_id with images".format(algorithm))
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
raise ValueError('Algorithm class:{} does not have mapping to account_id with images'.format(algorithm))

return get_ecr_image_uri_prefix(account_id, region_name)


def get_image_uri(region_name, repo_name, repo_version=1):
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/fw_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from __future__ import absolute_import
import logging

from sagemaker.utils import get_ecr_image_uri_prefix

image_registry_map = {
"us-west-1": {
"sparkml-serving": "746614075791",
Expand Down Expand Up @@ -69,6 +71,10 @@
"us-gov-west-1": {
"sparkml-serving": "414596584902",
"scikit-learn": "414596584902"
},
"us-iso-east-1": {
"sparkml-serving": "833128469047",
"scikit-learn": "833128469047"
}
}

Expand All @@ -80,7 +86,7 @@ def registry(region_name, framework=None):
"""
try:
account_id = image_registry_map[region_name][framework]
return "{}.dkr.ecr.{}.amazonaws.com".format(account_id, region_name)
return get_ecr_image_uri_prefix(account_id, region_name)
except KeyError:
logging.error("The specific image or region does not exist")
raise
Expand Down
9 changes: 6 additions & 3 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import tempfile
from six.moves.urllib.parse import urlparse

from sagemaker.utils import get_ecr_image_uri_prefix

_TAR_SOURCE_FILENAME = 'source.tar.gz'

UploadedCode = namedtuple('UserCode', ['s3_prefix', 'script_name'])
Expand All @@ -40,7 +42,8 @@

VALID_PY_VERSIONS = ['py2', 'py3']
VALID_EIA_FRAMEWORKS = ['tensorflow', 'mxnet']
VALID_ACCOUNTS_BY_REGION = {'us-gov-west-1': '246785580436'}
VALID_ACCOUNTS_BY_REGION = {'us-gov-west-1': '246785580436',
'us-iso-east-1': '744548109606'}


def create_image_uri(region, framework, instance_type, framework_version, py_version=None,
Expand Down Expand Up @@ -96,8 +99,8 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
optimized_families=optimized_families):
framework += '-eia'

return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \
.format(account, region, framework, tag)
return "{}/sagemaker-{}:{}" \
.format(get_ecr_image_uri_prefix(account, region), framework, tag)


def _accelerator_type_valid_for_framework(framework, accelerator_type=None, optimized_families=None):
Expand Down
14 changes: 14 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,20 @@ def download_file(bucket_name, path, target, sagemaker_session):
bucket.download_file(path, target)


def get_ecr_image_uri_prefix(account, region):
"""get prefix of ECR image URI

Args:
account (str): AWS account number
region (str): AWS region name

Returns:
(str): URI prefix of ECR image
"""
domain = 'c2s.ic.gov' if region == 'us-iso-east-1' else 'amazonaws.com'
return '{}.dkr.ecr.{}.{}'.format(account, region, domain)


class DeferredError(object):
"""Stores an exception and raises it at a later time if this
object is accessed in any way. Useful to allow soft-dependencies on imports,
Expand Down
10 changes: 9 additions & 1 deletion tests/unit/test_amazon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

# Use PCA as a test implementation of AmazonAlgorithmEstimator
from sagemaker.amazon.pca import PCA
from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry
from sagemaker.amazon.amazon_estimator import upload_numpy_to_s3_shards, _build_shards, registry, get_image_uri

COMMON_ARGS = {'role': 'myrole', 'train_instance_count': 1, 'train_instance_type': 'ml.c4.xlarge'}

Expand Down Expand Up @@ -61,6 +61,14 @@ def sagemaker_session():
return sms


def test_gov_ecr_uri():
assert get_image_uri('us-gov-west-1', 'kmeans', 'latest') == \
'226302683700.dkr.ecr.us-gov-west-1.amazonaws.com/kmeans:latest'

assert get_image_uri('us-iso-east-1', 'kmeans', 'latest') == \
'490574956308.dkr.ecr.us-iso-east-1.c2s.ic.gov/kmeans:latest'


def test_init(sagemaker_session):
pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)
assert pca.num_components == 55
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_fw_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def test_registry_sparkml_serving():
assert registry('eu-central-1', 'sparkml-serving') == "492215442770.dkr.ecr.eu-central-1.amazonaws.com"
assert registry('ca-central-1', 'sparkml-serving') == "341280168497.dkr.ecr.ca-central-1.amazonaws.com"
assert registry('us-gov-west-1', 'sparkml-serving') == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com"
assert registry('us-iso-east-1', 'sparkml-serving') == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov"


def test_registry_sklearn():
Expand All @@ -57,6 +58,7 @@ def test_registry_sklearn():
assert registry('eu-central-1', scikit_learn_framework_name) == "492215442770.dkr.ecr.eu-central-1.amazonaws.com"
assert registry('ca-central-1', scikit_learn_framework_name) == "341280168497.dkr.ecr.ca-central-1.amazonaws.com"
assert registry('us-gov-west-1', scikit_learn_framework_name) == "414596584902.dkr.ecr.us-gov-west-1.amazonaws.com"
assert registry('us-iso-east-1', scikit_learn_framework_name) == "833128469047.dkr.ecr.us-iso-east-1.c2s.ic.gov"


def test_default_sklearn_image_uri():
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ def test_create_image_uri_cpu():
image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'local', '1.0rc', 'py2', '23')
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'

image_uri = fw_utils.create_image_uri('us-gov-west-1', MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23')
assert image_uri == '246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'

image_uri = fw_utils.create_image_uri('us-iso-east-1', MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', 'py2', '23')
assert image_uri == '744548109606.dkr.ecr.us-iso-east-1.c2s.ic.gov/sagemaker-mlfw:1.0rc-cpu-py2'


def test_create_image_uri_no_python():
image_uri = fw_utils.create_image_uri(MOCK_REGION, MOCK_FRAMEWORK, 'ml.c4.large', '1.0rc', account='23')
Expand Down