Skip to content

Commit b1df490

Browse files
wweicyangaws
authored andcommitted
Add warning message if user tries to compile a model for edge device (#659)
1 parent 14be041 commit b1df490

File tree

2 files changed

+35
-3
lines changed

2 files changed

+35
-3
lines changed

src/sagemaker/model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from sagemaker import fw_utils, local, session, utils
2020
from sagemaker.transformer import Transformer
2121

22+
logging.basicConfig()
23+
LOGGER = logging.getLogger('sagemaker')
24+
LOGGER.setLevel(logging.INFO)
2225

2326
NEO_ALLOWED_TARGET_INSTANCE_FAMILY = set(['ml_c5', 'ml_m5', 'ml_c4', 'ml_m4', 'jetson_tx1', 'jetson_tx2', 'ml_p2',
2427
'ml_p3', 'deeplens', 'rasp3b'])
@@ -190,9 +193,13 @@ def compile(self, target_instance_family, input_shape, output_path, role,
190193
self.sagemaker_session.compile_model(**config)
191194
job_status = self.sagemaker_session.wait_for_compilation_job(job_name)
192195
self.model_data = job_status['ModelArtifacts']['S3ModelArtifacts']
193-
self.image = self._neo_image(self.sagemaker_session.boto_region_name, target_instance_family, framework,
194-
framework_version)
195-
self._is_compiled_model = True
196+
if target_instance_family.startswith('ml_'):
197+
self.image = self._neo_image(self.sagemaker_session.boto_region_name, target_instance_family, framework,
198+
framework_version)
199+
self._is_compiled_model = True
200+
else:
201+
LOGGER.warning("The intance type {} is not supported to deploy via SageMaker,"
202+
"please deploy the model on the device by yourself.".format(target_instance_family))
196203
return self
197204

198205
def deploy(self, initial_instance_count, instance_type, accelerator_type=None, endpoint_name=None,

tests/unit/test_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@
8888
'CertifyForMarketplace': False
8989
}
9090

91+
DESCRIBE_COMPILATION_JOB_RESPONSE = {
92+
'CompilationJobStatus': "Completed",
93+
'ModelArtifacts': {
94+
'S3ModelArtifacts': 's3://output-path/model.tar.gz'
95+
}
96+
}
97+
9198

9299
class DummyFrameworkModel(FrameworkModel):
93100

@@ -351,3 +358,21 @@ def test_delete_non_deployed_model(sagemaker_session):
351358
model = DummyFrameworkModel(sagemaker_session)
352359
with pytest.raises(ValueError, match='The SageMaker model must be created first before attempting to delete.'):
353360
model.delete_model()
361+
362+
363+
def test_compile_model_for_edge_device(sagemaker_session, tmpdir):
364+
sagemaker_session.wait_for_compilation_job = Mock(
365+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE)
366+
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
367+
model.compile(target_instance_family='deeplens', input_shape={'data': [1, 3, 1024, 1024]},
368+
output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model")
369+
assert model._is_compiled_model is False
370+
371+
372+
def test_compile_model_for_cloud(sagemaker_session, tmpdir):
373+
sagemaker_session.wait_for_compilation_job = Mock(
374+
return_value=DESCRIBE_COMPILATION_JOB_RESPONSE)
375+
model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir))
376+
model.compile(target_instance_family='ml_c4', input_shape={'data': [1, 3, 1024, 1024]},
377+
output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model")
378+
assert model._is_compiled_model is True

0 commit comments

Comments
 (0)