From 2d96997be510f4e8f482313d5ae0989e5dafdfdf Mon Sep 17 00:00:00 2001 From: zxy844288792 Date: Tue, 14 May 2019 15:07:57 -0700 Subject: [PATCH] feature: add region check for Neo service --- src/sagemaker/model.py | 13 +++++++++++++ tests/unit/test_model.py | 16 ++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 80b809363f..d81c35fcbc 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -149,6 +149,19 @@ def _compilation_job_config(self, target_instance_type, input_shape, output_path 'tags': tags, 'job_name': job_name} + def check_neo_region(self, region): + """Check if this ``Model`` in the available region where neo support. + + Args: + region (str): Specifies the region where want to execute compilation + Returns: + bool: boolean value whether if neo is available in the specified region + """ + if region in NEO_IMAGE_ACCOUNT: + return True + else: + return False + def _neo_image_account(self, region): if region not in NEO_IMAGE_ACCOUNT: raise ValueError("Neo is not currently supported in {}, " diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 1ba4fb28f9..d47a8fef90 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -400,3 +400,19 @@ def test_compile_model_for_cloud(sagemaker_session, tmpdir): model.compile(target_instance_family='ml_c4', input_shape={'data': [1, 3, 1024, 1024]}, output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model") assert model._is_compiled_model is True + + +def test_check_neo_region(sagemaker_session, tmpdir): + sagemaker_session.wait_for_compilation_job = Mock( + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE) + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) + ec2_region_list = ['us-east-2', 'us-east-1', 'us-west-1', 'us-west-2', 'ap-east-1', 'ap-south-1', + 'ap-northeast-3', 'ap-northeast-2', 'ap-southeast-1', 'ap-southeast-2', 'ap-northeast-1', + 'ca-central-1', 'cn-north-1', 'cn-northwest-1', 'eu-central-1', ' eu-west-1', 'eu-west-2', + 'eu-west-3', 'eu-north-1', 'sa-east-1', 'us-gov-east-1', 'us-gov-west-1'] + neo_support_region = ['us-west-2', 'eu-west-1', 'us-east-1', 'us-east-2'] + for region_name in ec2_region_list: + if region_name in neo_support_region: + assert model.check_neo_region(region_name) is True + else: + assert model.check_neo_region(region_name) is False