27
27
from mock import call , patch , Mock , MagicMock
28
28
29
29
import sagemaker
30
+ from sagemaker .session_settings import SessionSettings
30
31
31
32
BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission"
32
33
@@ -390,6 +391,13 @@ def test_repack_model_without_source_dir(tmp, fake_s3):
390
391
"/code/inference.py" ,
391
392
}
392
393
394
+ extra_args = {"ServerSideEncryption" : "aws:kms" }
395
+ object_mock = fake_s3 .object_mock
396
+ _ , _ , kwargs = object_mock .mock_calls [0 ]
397
+
398
+ assert "ExtraArgs" in kwargs
399
+ assert kwargs ["ExtraArgs" ] == extra_args
400
+
393
401
394
402
def test_repack_model_with_entry_point_without_path_without_source_dir (tmp , fake_s3 ):
395
403
@@ -415,12 +423,20 @@ def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake
415
423
"s3://fake/location" ,
416
424
"s3://destination-bucket/model.tar.gz" ,
417
425
fake_s3 .sagemaker_session ,
426
+ kms_key = "kms_key" ,
418
427
)
419
428
finally :
420
429
os .chdir (cwd )
421
430
422
431
assert list_tar_files (fake_s3 .fake_upload_path , tmp ) == {"/code/inference.py" , "/model" }
423
432
433
+ extra_args = {"ServerSideEncryption" : "aws:kms" , "SSEKMSKeyId" : "kms_key" }
434
+ object_mock = fake_s3 .object_mock
435
+ _ , _ , kwargs = object_mock .mock_calls [0 ]
436
+
437
+ assert "ExtraArgs" in kwargs
438
+ assert kwargs ["ExtraArgs" ] == extra_args
439
+
424
440
425
441
def test_repack_model_from_s3_to_s3 (tmp , fake_s3 ):
426
442
@@ -434,6 +450,7 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
434
450
)
435
451
436
452
fake_s3 .tar_and_upload ("model-dir" , "s3://fake/location" )
453
+ fake_s3 .sagemaker_session .settings = SessionSettings (encrypt_repacked_artifacts = False )
437
454
438
455
sagemaker .utils .repack_model (
439
456
"inference.py" ,
@@ -450,6 +467,11 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
450
467
"/model" ,
451
468
}
452
469
470
+ object_mock = fake_s3 .object_mock
471
+ _ , _ , kwargs = object_mock .mock_calls [0 ]
472
+ assert "ExtraArgs" in kwargs
473
+ assert kwargs ["ExtraArgs" ] is None
474
+
453
475
454
476
def test_repack_model_from_file_to_file (tmp ):
455
477
create_file_tree (tmp , ["model" , "dependencies/a" , "source-dir/inference.py" ])
@@ -581,6 +603,7 @@ def __init__(self, tmp):
581
603
self .sagemaker_session = MagicMock ()
582
604
self .location_map = {}
583
605
self .current_bucket = None
606
+ self .object_mock = MagicMock ()
584
607
585
608
self .sagemaker_session .boto_session .resource ().Bucket ().download_file .side_effect = (
586
609
self .download_file
@@ -606,6 +629,7 @@ def tar_and_upload(self, path, fake_location):
606
629
607
630
def mock_s3_upload (self ):
608
631
dst = os .path .join (self .tmp , "dst" )
632
+ object_mock = self .object_mock
609
633
610
634
class MockS3Object (object ):
611
635
def __init__ (self , bucket , key ):
@@ -616,6 +640,7 @@ def upload_file(self, target, **kwargs):
616
640
if self .bucket in BUCKET_WITHOUT_WRITING_PERMISSION :
617
641
raise exceptions .S3UploadFailedError ()
618
642
shutil .copy2 (target , dst )
643
+ object_mock .upload_file (target , ** kwargs )
619
644
620
645
self .sagemaker_session .boto_session .resource ().Object = MockS3Object
621
646
return dst
0 commit comments