Skip to content

Commit a39efe4

Browse files
Xingyu Zhouchuyang-deng
Xingyu Zhou
authored andcommitted
feature: add region check for Neo service (#806)
1 parent 69a1685 commit a39efe4

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

src/sagemaker/model.py

+13
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,19 @@ def _compilation_job_config(self, target_instance_type, input_shape, output_path
151151
'tags': tags,
152152
'job_name': job_name}
153153

154+
def check_neo_region(self, region):
155+
"""Check if this ``Model`` in the available region where neo support.
156+
157+
Args:
158+
region (str): Specifies the region where want to execute compilation
159+
Returns:
160+
bool: boolean value whether if neo is available in the specified region
161+
"""
162+
if region in NEO_IMAGE_ACCOUNT:
163+
return True
164+
else:
165+
return False
166+
154167
def _neo_image_account(self, region):
155168
if region not in NEO_IMAGE_ACCOUNT:
156169
raise ValueError("Neo is not currently supported in {}, "

tests/unit/test_model.py

+16
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,19 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
400400
model.compile(target_instance_family='ml_c4', input_shape={'data': [1, 3, 1024, 1024]},
401401
output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model")
402402
assert model._is_compiled_model is True
403+
404+
405+
def test_check_neo_region(sagemaker_session, tmpdir):
406+
sagemaker_session.wait_for_compilation_job = Mock(
407+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE)
408+
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
409+
ec2_region_list = ['us-east-2', 'us-east-1', 'us-west-1', 'us-west-2', 'ap-east-1', 'ap-south-1',
410+
'ap-northeast-3', 'ap-northeast-2', 'ap-southeast-1', 'ap-southeast-2', 'ap-northeast-1',
411+
'ca-central-1', 'cn-north-1', 'cn-northwest-1', 'eu-central-1', ' eu-west-1', 'eu-west-2',
412+
'eu-west-3', 'eu-north-1', 'sa-east-1', 'us-gov-east-1', 'us-gov-west-1']
413+
neo_support_region = ['us-west-2', 'eu-west-1', 'us-east-1', 'us-east-2']
414+
for region_name in ec2_region_list:
415+
if region_name in neo_support_region:
416+
assert model.check_neo_region(region_name) is True
417+
else:
418+
assert model.check_neo_region(region_name) is False

0 commit comments

Comments
 (0)