20
20
MODEL_DATA = "s3://bucket/model.tar.gz"
21
21
MODEL_IMAGE = "mi"
22
22
23
+ IMAGE_URI = "inference-container-uri"
24
+
23
25
REGION = "us-west-2"
24
26
25
27
NEO_REGION_ACCOUNT = "301217895009"
26
28
DESCRIBE_COMPILATION_JOB_RESPONSE = {
27
29
"CompilationJobStatus" : "Completed" ,
28
30
"ModelArtifacts" : {"S3ModelArtifacts" : "s3://output-path/model.tar.gz" },
31
+ "InferenceImage" : IMAGE_URI ,
29
32
}
30
33
31
34
@@ -52,12 +55,7 @@ def test_compile_model_for_inferentia(sagemaker_session):
52
55
framework_version = "1.15.0" ,
53
56
job_name = "compile-model" ,
54
57
)
55
- assert (
56
- "{}.dkr.ecr.{}.amazonaws.com/sagemaker-neo-tensorflow:1.15.0-inf-py3" .format (
57
- NEO_REGION_ACCOUNT , REGION
58
- )
59
- == model .image_uri
60
- )
58
+ assert DESCRIBE_COMPILATION_JOB_RESPONSE ["InferenceImage" ] == model .image_uri
61
59
assert model ._is_compiled_model is True
62
60
63
61
@@ -271,11 +269,12 @@ def test_deploy_add_compiled_model_suffix_to_endpoint_name_from_model_name(sagem
271
269
assert model .endpoint_name .startswith ("{}-ml-c4" .format (model_name ))
272
270
273
271
274
- @patch ("sagemaker.session.Session" )
275
- def test_compile_with_framework_version_15 (session ):
276
- session .return_value .boto_region_name = REGION
272
+ def test_compile_with_framework_version_15 (sagemaker_session ):
273
+ sagemaker_session .wait_for_compilation_job = Mock (
274
+ return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
275
+ )
277
276
278
- model = _create_model ()
277
+ model = _create_model (sagemaker_session )
279
278
model .compile (
280
279
target_instance_family = "ml_c4" ,
281
280
input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
@@ -286,14 +285,15 @@ def test_compile_with_framework_version_15(session):
286
285
job_name = "compile-model" ,
287
286
)
288
287
289
- assert "1.5" in model .image_uri
288
+ assert IMAGE_URI == model .image_uri
290
289
291
290
292
- @patch ("sagemaker.session.Session" )
293
- def test_compile_with_framework_version_16 (session ):
294
- session .return_value .boto_region_name = REGION
291
+ def test_compile_with_framework_version_16 (sagemaker_session ):
292
+ sagemaker_session .wait_for_compilation_job = Mock (
293
+ return_value = DESCRIBE_COMPILATION_JOB_RESPONSE
294
+ )
295
295
296
- model = _create_model ()
296
+ model = _create_model (sagemaker_session )
297
297
model .compile (
298
298
target_instance_family = "ml_c4" ,
299
299
input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
@@ -304,26 +304,7 @@ def test_compile_with_framework_version_16(session):
304
304
job_name = "compile-model" ,
305
305
)
306
306
307
- assert "1.6" in model .image_uri
308
-
309
-
310
- @patch ("sagemaker.session.Session" )
311
- def test_compile_validates_framework_version (session ):
312
- session .return_value .boto_region_name = REGION
313
-
314
- model = _create_model ()
315
- with pytest .raises (ValueError ) as e :
316
- model .compile (
317
- target_instance_family = "ml_c4" ,
318
- input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
319
- output_path = "s3://output" ,
320
- role = "role" ,
321
- framework = "pytorch" ,
322
- framework_version = "1.6.1" ,
323
- job_name = "compile-model" ,
324
- )
325
-
326
- assert "Unsupported neo-pytorch version: 1.6.1." in str (e )
307
+ assert IMAGE_URI == model .image_uri
327
308
328
309
329
310
@patch ("sagemaker.session.Session" )
@@ -347,3 +328,25 @@ def test_compile_with_pytorch_neo_in_ml_inf(session):
347
328
)
348
329
!= model .image_uri
349
330
)
331
+
332
+
333
+ def test_compile_validates_framework_version (sagemaker_session ):
334
+ sagemaker_session .wait_for_compilation_job = Mock (
335
+ return_value = {
336
+ "CompilationJobStatus" : "Completed" ,
337
+ "ModelArtifacts" : {"S3ModelArtifacts" : "s3://output-path/model.tar.gz" },
338
+ "InferenceImage" : None ,
339
+ }
340
+ )
341
+ model = _create_model (sagemaker_session )
342
+ model .compile (
343
+ target_instance_family = "ml_c4" ,
344
+ input_shape = {"data" : [1 , 3 , 1024 , 1024 ]},
345
+ output_path = "s3://output" ,
346
+ role = "role" ,
347
+ framework = "pytorch" ,
348
+ framework_version = "1.6.1" ,
349
+ job_name = "compile-model" ,
350
+ )
351
+
352
+ assert model .image_uri is None
0 commit comments