Skip to content

Commit 51559f9

Browse files
author
zxy844288792
committed
feature: add region check for Neo service
1 parent fa8dd99 commit 51559f9

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

src/sagemaker/model.py

+13
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,19 @@ def _compilation_job_config(self, target_instance_type, input_shape, output_path
149149
'tags': tags,
150150
'job_name': job_name}
151151

152+
def check_neo_region(self, region):
153+
"""Check if this ``Model`` in the available region where neo support.
154+
155+
Args:
156+
region (str): Specifies the region where want to execute compilation
157+
Returns:
158+
bool: boolean value whether if neo is available in the specified region
159+
"""
160+
if region in NEO_IMAGE_ACCOUNT:
161+
return True
162+
else:
163+
return False
164+
152165
def _neo_image_account(self, region):
153166
if region not in NEO_IMAGE_ACCOUNT:
154167
raise ValueError("Neo is not currently supported in {}, "

tests/unit/test_model.py

+16
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,19 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir):
400400
model.compile(target_instance_family='ml_c4', input_shape={'data': [1, 3, 1024, 1024]},
401401
output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model")
402402
assert model._is_compiled_model is True
403+
404+
405+
def test_check_neo_region(sagemaker_session, tmpdir):
406+
sagemaker_session.wait_for_compilation_job = Mock(
407+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE)
408+
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
409+
ec2_region_list = ['us-east-2', 'us-east-1', 'us-west-1', 'us-west-2', 'ap-east-1', 'ap-south-1',
410+
'ap-northeast-3', 'ap-northeast-2', 'ap-southeast-1', 'ap-southeast-2', 'ap-northeast-1',
411+
'ca-central-1', 'cn-north-1', 'cn-northwest-1', 'eu-central-1', ' eu-west-1', 'eu-west-2',
412+
'eu-west-3', 'eu-north-1', 'sa-east-1', 'us-gov-east-1', 'us-gov-west-1']
413+
neo_support_region = ['us-west-2', 'eu-west-1', 'us-east-1', 'us-east-2']
414+
for region_name in ec2_region_list:
415+
if region_name in neo_support_region:
416+
assert model.check_neo_region(region_name) is True
417+
else:
418+
assert model.check_neo_region(region_name) is False

0 commit comments

Comments
 (0)