|
88 | 88 | 'CertifyForMarketplace': False
|
89 | 89 | }
|
90 | 90 |
|
| 91 | +DESCRIBE_COMPILATION_JOB_RESPONSE = { |
| 92 | + 'CompilationJobStatus': "Completed", |
| 93 | + 'ModelArtifacts': { |
| 94 | + 'S3ModelArtifacts': 's3://output-path/model.tar.gz' |
| 95 | + } |
| 96 | +} |
| 97 | + |
91 | 98 |
|
92 | 99 | class DummyFrameworkModel(FrameworkModel):
|
93 | 100 |
|
@@ -351,3 +358,21 @@ def test_delete_non_deployed_model(sagemaker_session):
|
351 | 358 | model = DummyFrameworkModel(sagemaker_session)
|
352 | 359 | with pytest.raises(ValueError, match='The SageMaker model must be created first before attempting to delete.'):
|
353 | 360 | model.delete_model()
|
| 361 | + |
| 362 | + |
| 363 | +def test_compile_model_for_edge_device(sagemaker_session, tmpdir): |
| 364 | + sagemaker_session.wait_for_compilation_job = Mock( |
| 365 | + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE) |
| 366 | + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
| 367 | + model.compile(target_instance_family='deeplens', input_shape={'data': [1, 3, 1024, 1024]}, |
| 368 | + output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model") |
| 369 | + assert model._is_compiled_model is False |
| 370 | + |
| 371 | + |
| 372 | +def test_compile_model_for_cloud(sagemaker_session, tmpdir): |
| 373 | + sagemaker_session.wait_for_compilation_job = Mock( |
| 374 | + return_value=DESCRIBE_COMPILATION_JOB_RESPONSE) |
| 375 | + model = DummyFrameworkModel(sagemaker_session, source_dir=str(tmpdir)) |
| 376 | + model.compile(target_instance_family='ml_c4', input_shape={'data': [1, 3, 1024, 1024]}, |
| 377 | + output_path='s3://output', role='role', framework='tensorflow', job_name="compile-model") |
| 378 | + assert model._is_compiled_model is True |
0 commit comments