Skip to content

Commit 1124993

Browse files
yangawsPiali Das
authored and
Piali Das
committed
Add gov cloud account number for framework (aws#400)
1 parent 400f25e commit 1124993

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/sagemaker/fw_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
4949
str: The appropriate image URI based on the given parameters.
5050
"""
5151

52+
# Handle Account Number for Gov Cloud
53+
if region == 'us-gov-west-1':
54+
account = '246785580436'
55+
5256
# Handle Local Mode
5357
if instance_type.startswith('local'):
5458
device_type = 'cpu' if instance_type == 'local' else 'gpu'

tests/unit/test_fw_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def test_create_image_uri_default_account():
6262
assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
6363

6464

65+
def test_create_image_uri_gov_cloud():
66+
image_uri = create_image_uri('us-gov-west-1', 'mlfw', 'ml.p3.2xlarge', '1.0rc', 'py3')
67+
assert image_uri == '246785580436.dkr.ecr.us-gov-west-1.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
68+
69+
6570
def test_invalid_instance_type():
6671
# instance type is missing 'ml.' prefix
6772
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)