Skip to content

Commit c5890d2

Browse files
authored
feature: support cn-north-1 and cn-northwest-1 (#1380)
1 parent 54f31ee commit c5890d2

18 files changed

+122
-39
lines changed

.github/PULL_REQUEST_TEMPLATE.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ _Put an `x` in the boxes that apply. You can also fill these out after creating
1212

1313
- [ ] I have read the [CONTRIBUTING](https://github.com/aws/sagemaker-python-sdk/blob/master/CONTRIBUTING.md) doc
1414
- [ ] I used the commit message format described in [CONTRIBUTING](https://github.com/aws/sagemaker-python-sdk/blob/master/CONTRIBUTING.md#committing-your-change)
15-
- [ ] I have used the regional endpoint when creating S3 and/or STS clients (if appropriate)
15+
- [ ] I have passed the region in to any/all clients that I've initialized as part of this change.
1616
- [ ] I have updated any necessary documentation, including [READMEs](https://github.com/aws/sagemaker-python-sdk/blob/master/README.rst) and [API docs](https://github.com/aws/sagemaker-python-sdk/tree/master/doc) (if appropriate)
1717

1818
#### Tests

src/sagemaker/amazon/amazon_estimator.py

+34-13
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,9 @@ def record_set(self, train, labels=None, channel="train", encrypt=False):
281281
RecordSet: A RecordSet referencing the encoded, uploading training
282282
and label data.
283283
"""
284-
s3 = self.sagemaker_session.boto_session.resource("s3")
284+
s3 = self.sagemaker_session.boto_session.resource(
285+
"s3", region_name=self.sagemaker_session.boto_region_name
286+
)
285287
parsed_s3_url = urlparse(self.data_location)
286288
bucket, key_prefix = parsed_s3_url.netloc, parsed_s3_url.path
287289
key_prefix = key_prefix + "{}-{}/".format(type(self).__name__, sagemaker_timestamp())
@@ -467,9 +469,14 @@ def registry(region_name, algorithm=None):
467469
https://github.com/aws/sagemaker-python-sdk/tree/master/src/sagemaker/amazon
468470
469471
Args:
470-
region_name:
471-
algorithm:
472+
region_name (str): The region name for the account.
473+
algorithm (str): The algorithm for the account.
474+
475+
Raises:
476+
ValueError: If invalid algorithm passed in or if mapping does not exist for given algorithm
477+
and region.
472478
"""
479+
region_to_accounts = {}
473480
if algorithm in [
474481
None,
475482
"pca",
@@ -482,7 +489,7 @@ def registry(region_name, algorithm=None):
482489
"object2vec",
483490
"ipinsights",
484491
]:
485-
account_id = {
492+
region_to_accounts = {
486493
"us-east-1": "382416733822",
487494
"us-east-2": "404615174143",
488495
"us-west-2": "174872318107",
@@ -503,9 +510,11 @@ def registry(region_name, algorithm=None):
503510
"eu-west-3": "749696950732",
504511
"sa-east-1": "855470959533",
505512
"me-south-1": "249704162688",
506-
}[region_name]
513+
"cn-north-1": "390948362332",
514+
"cn-northwest-1": "387376663083",
515+
}
507516
elif algorithm in ["lda"]:
508-
account_id = {
517+
region_to_accounts = {
509518
"us-east-1": "766337827248",
510519
"us-east-2": "999911452149",
511520
"us-west-2": "266724342769",
@@ -521,9 +530,9 @@ def registry(region_name, algorithm=None):
521530
"eu-west-2": "644912444149",
522531
"us-west-1": "632365934929",
523532
"us-iso-east-1": "490574956308",
524-
}[region_name]
533+
}
525534
elif algorithm in ["forecasting-deepar"]:
526-
account_id = {
535+
region_to_accounts = {
527536
"us-east-1": "522234722520",
528537
"us-east-2": "566113047672",
529538
"us-west-2": "156387875391",
@@ -544,7 +553,9 @@ def registry(region_name, algorithm=None):
544553
"eu-west-3": "749696950732",
545554
"sa-east-1": "855470959533",
546555
"me-south-1": "249704162688",
547-
}[region_name]
556+
"cn-north-1": "390948362332",
557+
"cn-northwest-1": "387376663083",
558+
}
548559
elif algorithm in [
549560
"xgboost",
550561
"seq2seq",
@@ -553,7 +564,7 @@ def registry(region_name, algorithm=None):
553564
"object-detection",
554565
"semantic-segmentation",
555566
]:
556-
account_id = {
567+
region_to_accounts = {
557568
"us-east-1": "811284229777",
558569
"us-east-2": "825641698319",
559570
"us-west-2": "433757028032",
@@ -574,15 +585,25 @@ def registry(region_name, algorithm=None):
574585
"eu-west-3": "749696950732",
575586
"sa-east-1": "855470959533",
576587
"me-south-1": "249704162688",
577-
}[region_name]
588+
"cn-north-1": "390948362332",
589+
"cn-northwest-1": "387376663083",
590+
}
578591
elif algorithm in ["image-classification-neo", "xgboost-neo"]:
579-
account_id = NEO_IMAGE_ACCOUNT[region_name]
592+
region_to_accounts = NEO_IMAGE_ACCOUNT
580593
else:
581594
raise ValueError(
582595
"Algorithm class:{} does not have mapping to account_id with images".format(algorithm)
583596
)
584597

585-
return get_ecr_image_uri_prefix(account_id, region_name)
598+
if region_name in region_to_accounts:
599+
account_id = region_to_accounts[region_name]
600+
return get_ecr_image_uri_prefix(account_id, region_name)
601+
602+
raise ValueError(
603+
"Algorithm ({algorithm}) is unsupported for region ({region_name}).".format(
604+
algorithm=algorithm, region_name=region_name
605+
)
606+
)
586607

587608

588609
def get_image_uri(region_name, repo_name, repo_version=1):

src/sagemaker/debugger.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import smdebug_rulesconfig as rule_configs # noqa: F401 # pylint: disable=unused-import
2525

26+
from sagemaker.utils import get_ecr_image_uri_prefix
2627

2728
RULES_ECR_REPO_NAME = "sagemaker-debugger-rules"
2829

@@ -45,6 +46,8 @@
4546
"ap-southeast-1": {RULES_ECR_REPO_NAME: "972752614525"},
4647
"ap-southeast-2": {RULES_ECR_REPO_NAME: "184798709955"},
4748
"ca-central-1": {RULES_ECR_REPO_NAME: "519511493484"},
49+
"cn-north-1": {RULES_ECR_REPO_NAME: "618459771430"},
50+
"cn-northwest-1": {RULES_ECR_REPO_NAME: "658757709296"},
4851
}
4952

5053

@@ -59,7 +62,8 @@ def get_rule_container_image_uri(region):
5962
str: Formatted image uri for the given region and the rule container type
6063
"""
6164
registry_id = SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP.get(region).get(RULES_ECR_REPO_NAME)
62-
return "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(registry_id, region, RULES_ECR_REPO_NAME)
65+
image_uri_prefix = get_ecr_image_uri_prefix(registry_id, region)
66+
return "{}/{}:latest".format(image_uri_prefix, RULES_ECR_REPO_NAME)
6367

6468

6569
class Rule(object):

src/sagemaker/fw_registry.py

+10
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,16 @@
117117
"scikit-learn": "801668240914",
118118
"xgboost": "801668240914",
119119
},
120+
"cn-north-1": {
121+
"sparkml-serving": "450853457545",
122+
"scikit-learn": "450853457545",
123+
"xgboost": "450853457545",
124+
},
125+
"cn-northwest-1": {
126+
"sparkml-serving": "451049120500",
127+
"scikit-learn": "451049120500",
128+
"xgboost": "451049120500",
129+
},
120130
}
121131

122132

src/sagemaker/fw_utils.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,18 @@
7171
"pytorch-serving",
7272
]
7373
PY2_RESTRICTED_EIA_FRAMEWORKS = ["pytorch-serving"]
74-
VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"}
75-
ASIMOV_VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074"}
74+
VALID_ACCOUNTS_BY_REGION = {
75+
"us-gov-west-1": "246785580436",
76+
"us-iso-east-1": "744548109606",
77+
"cn-north-1": "422961961927",
78+
"cn-northwest-1": "423003514399",
79+
}
80+
ASIMOV_VALID_ACCOUNTS_BY_REGION = {
81+
"us-gov-west-1": "442386744353",
82+
"us-iso-east-1": "886529160074",
83+
"cn-north-1": "727897471807",
84+
"cn-northwest-1": "727897471807",
85+
}
7686
OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634", "me-south-1": "724002660598"}
7787
ASIMOV_OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "871362719292", "me-south-1": "217643126080"}
7888
DEFAULT_ACCOUNT = "520713654638"

src/sagemaker/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
"sa-east-1": "756306329178",
4848
"ca-central-1": "464438896020",
4949
"me-south-1": "836785723513",
50+
"cn-north-1": "472730292857",
51+
"cn-northwest-1": "474822919863",
5052
"us-gov-west-1": "263933020539",
5153
}
5254

src/sagemaker/model_monitor/model_monitoring.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,9 @@
3232
from sagemaker.processing import Processor, ProcessingInput, ProcessingJob, ProcessingOutput
3333
from sagemaker.s3 import S3Uploader
3434
from sagemaker.session import Session
35-
from sagemaker.utils import name_from_base, retries
35+
from sagemaker.utils import name_from_base, retries, get_ecr_image_uri_prefix
3636

37-
_DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS = (
38-
"{}.dkr.ecr.{}.amazonaws.com/sagemaker-model-monitor-analyzer"
39-
)
37+
_DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS = "{}/sagemaker-model-monitor-analyzer"
4038

4139
_DEFAULT_MONITOR_IMAGE_REGION_ACCOUNT_MAPPING = {
4240
"eu-north-1": "895015795356",
@@ -57,6 +55,8 @@
5755
"ap-southeast-1": "245545462676",
5856
"ap-southeast-2": "563025443158",
5957
"ca-central-1": "536280801234",
58+
"cn-north-1": "453000072557",
59+
"cn-northwest-1": "453252182341",
6060
}
6161

6262
STATISTICS_JSON_DEFAULT_FILE_NAME = "statistics.json"
@@ -1761,7 +1761,7 @@ def _get_default_image_uri(region):
17611761
str: The Default Model Monitoring image uri based on the region.
17621762
"""
17631763
return _DEFAULT_MONITOR_IMAGE_URI_WITH_PLACEHOLDERS.format(
1764-
_DEFAULT_MONITOR_IMAGE_REGION_ACCOUNT_MAPPING[region], region
1764+
get_ecr_image_uri_prefix(_DEFAULT_MONITOR_IMAGE_REGION_ACCOUNT_MAPPING[region], region)
17651765
)
17661766

17671767

src/sagemaker/utils.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -509,9 +509,9 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
509509
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
510510
else:
511511
extra_args = None
512-
sagemaker_session.boto_session.resource("s3").Object(bucket, new_key).upload_file(
513-
tmp_model_path, ExtraArgs=extra_args
514-
)
512+
sagemaker_session.boto_session.resource(
513+
"s3", region_name=sagemaker_session.boto_region_name
514+
).Object(bucket, new_key).upload_file(tmp_model_path, ExtraArgs=extra_args)
515515
else:
516516
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
517517

@@ -604,7 +604,7 @@ def download_file(bucket_name, path, target, sagemaker_session):
604604
path = path.lstrip("/")
605605
boto_session = sagemaker_session.boto_session
606606

607-
s3 = boto_session.resource("s3")
607+
s3 = boto_session.resource("s3", region_name=sagemaker_session.boto_region_name)
608608
bucket = s3.Bucket(bucket_name)
609609
bucket.download_file(path, target)
610610

tests/integ/kms_utils.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def get_or_create_kms_key(
141141
"Resource": "arn:{partition}:s3:::{bucket_name}/*",
142142
"Condition": {{
143143
"StringNotEquals": {{
144-
"s3:x-amz-server-side-encryption": "{partition}:kms"
144+
"s3:x-amz-server-side-encryption": "aws:kms"
145145
}}
146146
}}
147147
}},
@@ -172,7 +172,7 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
172172
account = sts_client.get_caller_identity()["Account"]
173173
role_arn = sts_client.get_caller_identity()["Arn"]
174174

175-
kms_client = boto_session.client("kms")
175+
kms_client = boto_session.client("kms", region_name=region)
176176
kms_key_arn = _create_kms_key(kms_client, account, region, role_arn, sagemaker_role, None)
177177

178178
region = boto_session.region_name
@@ -187,9 +187,7 @@ def bucket_with_encryption(sagemaker_session, sagemaker_role):
187187
"Rules": [
188188
{
189189
"ApplyServerSideEncryptionByDefault": {
190-
"SSEAlgorithm": "{partition}:kms".format(
191-
partition=utils._aws_partition(region)
192-
),
190+
"SSEAlgorithm": "aws:kms",
193191
"KMSMasterKeyID": kms_key_arn,
194192
}
195193
}

tests/integ/marketplace_utils.py

+3
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,7 @@
3030
"eu-north-1": "136758871317",
3131
"sa-east-1": "270155090741",
3232
"ap-east-1": "822005858737",
33+
"me-south-1": "335155493544",
34+
"cn-north-1": "295401494951",
35+
"cn-northwest-1": "304690803264",
3336
}

tests/integ/test_debugger.py

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
"us-east-2": "840043622174",
5454
"us-west-1": "952348334681",
5555
"us-west-2": "759209512951",
56+
"cn-north-1": "617202126805",
57+
"cn-northwest-1": "658559488188",
5658
}
5759

5860
# TODO-reinvent-2019: test get_debugger_artifacts_path and get_tensorboard_artifacts_path

tests/integ/test_horovod.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def read_json(file, tmp):
8686
return json.load(f)
8787

8888

89-
def extract_files_from_s3(s3_url, tmpdir):
89+
def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
9090
parsed_url = urlparse(s3_url)
91-
s3 = boto3.resource("s3")
91+
s3 = boto3.resource("s3", region_name=sagemaker_session.boto_region_name)
9292

9393
model = os.path.join(tmpdir, "model")
9494
s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)
@@ -115,7 +115,7 @@ def _create_and_fit_estimator(sagemaker_session, instance_type, tmpdir):
115115
estimator.fit(job_name=job_name)
116116

117117
tmp = str(tmpdir)
118-
extract_files_from_s3(estimator.model_data, tmp)
118+
extract_files_from_s3(estimator.model_data, tmp, sagemaker_session)
119119

120120
for rank in range(2):
121121
assert read_json("rank-%s" % rank, tmp)["rank"] == rank

tests/integ/test_model_monitor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
DEFAULT_VOLUME_SIZE_IN_GB = 30
6262
DEFAULT_BASELINING_MAX_RUNTIME_IN_SECONDS = 86400
6363
DEFAULT_EXECUTION_MAX_RUNTIME_IN_SECONDS = 3600
64-
DEFAULT_IMAGE_SUFFIX = ".com/sagemaker-model-monitor-analyzer"
64+
DEFAULT_IMAGE_SUFFIX = "/sagemaker-model-monitor-analyzer"
6565

6666
UPDATED_ROLE = "SageMakerRole"
6767
UPDATED_INSTANCE_COUNT = 2

tests/integ/test_multidatamodel.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from sagemaker.multidatamodel import MultiDataModel
2828
from sagemaker.mxnet import MXNet
2929
from sagemaker.predictor import RealTimePredictor, StringDeserializer, npy_serializer
30-
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base
30+
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
3131
from tests.integ import DATA_DIR, PYTHON_VERSION, TRAINING_DEFAULT_TIMEOUT_MINUTES
3232
from tests.integ.retry import retries
3333
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -49,8 +49,9 @@ def container_image(sagemaker_session):
4949
)
5050
account_id = sts_client.get_caller_identity()["Account"]
5151
algorithm_name = "sagemaker-multimodel-integ-test-{}".format(sagemaker_timestamp())
52-
ecr_image = "{account}.dkr.ecr.{region}.amazonaws.com/{algorithm_name}:latest".format(
53-
account=account_id, region=region, algorithm_name=algorithm_name
52+
ecr_image_uri_prefix = get_ecr_image_uri_prefix(account=account_id, region=region)
53+
ecr_image = "{prefix}/{algorithm_name}:latest".format(
54+
prefix=ecr_image_uri_prefix, algorithm_name=algorithm_name
5455
)
5556

5657
# Build and tag docker image locally

tests/integ/test_session.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,5 @@ def test_sagemaker_session_does_not_create_bucket_on_init(
4646
default_bucket=CUSTOM_BUCKET_NAME,
4747
)
4848

49-
s3 = boto3.resource("s3")
49+
s3 = boto3.resource("s3", region_name=DEFAULT_REGION)
5050
assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None

tests/unit/test_amazon_estimator.py

+12
Original file line numberDiff line numberDiff line change
@@ -462,3 +462,15 @@ def test_get_xgboost_image_uri():
462462
updated_xgb_image_uri_v2
463463
== "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3"
464464
)
465+
466+
467+
def test_regitry_throws_error_if_mapping_does_not_exist_for_lda():
468+
with pytest.raises(ValueError) as error:
469+
registry("cn-north-1", "lda")
470+
assert "Algorithm (lda) is unsupported for region (cn-north-1)." in str(error)
471+
472+
473+
def test_regitry_throws_error_if_mapping_does_not_exist_for_default_algorithm():
474+
with pytest.raises(ValueError) as error:
475+
registry("broken_region_name")
476+
assert "Algorithm (None) is unsupported for region (broken_region_name)." in str(error)

tests/unit/test_fw_utils.py

+18
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,24 @@ def test_create_image_uri_bah():
279279
}
280280

281281

282+
def test_create_image_uri_cn_north_1():
283+
image_uri = fw_utils.create_image_uri(
284+
"cn-north-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3"
285+
)
286+
assert {
287+
image_uri == "727897471807.dkr.ecr.me-south-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"
288+
}
289+
290+
291+
def test_create_image_uri_cn_northwest_1():
292+
image_uri = fw_utils.create_image_uri(
293+
"cn-northwest-1", MOCK_FRAMEWORK, "ml.p3.2xlarge", "1.0rc", "py3"
294+
)
295+
assert {
296+
image_uri == "727897471807.dkr.ecr.me-south-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3"
297+
}
298+
299+
282300
def test_tf_eia_images():
283301
image_uri = fw_utils.create_image_uri(
284302
"us-west-2",

0 commit comments

Comments
 (0)