22
22
from mock .mock import MagicMock
23
23
import pytest
24
24
from mock import patch
25
-
25
+ from sagemaker . session_settings import SessionSettings
26
26
from sagemaker .jumpstart .cache import JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY , JumpStartModelsCache
27
27
from sagemaker .jumpstart .constants import (
28
28
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE ,
45
45
from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
46
46
47
47
48
+ REGION = "us-east-1"
49
+ REGION2 = "us-east-2"
50
+ ACCOUNT_ID = "123456789123"
51
+
52
+
53
+ @pytest .fixture ()
54
+ def sagemaker_session ():
55
+ mocked_boto_session = Mock (name = "boto_session" )
56
+ mocked_s3_client = Mock (name = "s3_client" )
57
+ mocked_sagemaker_session = Mock (
58
+ name = "sagemaker_session" , boto_session = mocked_boto_session , s3_client = mocked_s3_client , boto_region_name = REGION , config = None ,
59
+ )
60
+ mocked_sagemaker_session .sagemaker_config = {}
61
+ mocked_sagemaker_session ._client_config .user_agent = (
62
+ "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource"
63
+ )
64
+ mocked_sagemaker_session .account_id .return_value = ACCOUNT_ID
65
+ return mocked_sagemaker_session
66
+
67
+
68
+
48
69
@patch .object (JumpStartModelsCache , "_retrieval_function" , patched_retrieval_function )
49
70
@patch ("sagemaker.jumpstart.utils.get_sagemaker_version" , lambda : "2.68.3" )
50
71
def test_jumpstart_cache_get_header ():
@@ -252,14 +273,14 @@ def test_jumpstart_cache_handles_boto3_issues(mock_boto3_client):
252
273
@patch ("boto3.client" )
253
274
def test_jumpstart_cache_gets_cleared_when_params_are_set (mock_boto3_client ):
254
275
cache = JumpStartModelsCache (
255
- s3_bucket_name = "some_bucket" , region = "us-west-2" , manifest_file_s3_key = "some_key"
276
+ s3_bucket_name = "some_bucket" , region = REGION , manifest_file_s3_key = "some_key"
256
277
)
257
278
258
279
cache .clear = MagicMock ()
259
280
cache .set_s3_bucket_name ("some_bucket" )
260
281
cache .clear .assert_not_called ()
261
282
cache .clear .reset_mock ()
262
- cache .set_region ("us-west-2" )
283
+ cache .set_region (REGION )
263
284
cache .clear .assert_not_called ()
264
285
cache .clear .reset_mock ()
265
286
cache .set_manifest_file_s3_key ("some_key" )
@@ -270,7 +291,7 @@ def test_jumpstart_cache_gets_cleared_when_params_are_set(mock_boto3_client):
270
291
cache .set_s3_bucket_name ("some_bucket1" )
271
292
cache .clear .assert_called_once ()
272
293
cache .clear .reset_mock ()
273
- cache .set_region ("us-east-1" )
294
+ cache .set_region (REGION2 )
274
295
cache .clear .assert_called_once ()
275
296
cache .clear .reset_mock ()
276
297
cache .set_manifest_file_s3_key ("some_key1" )
@@ -399,7 +420,6 @@ def test_jumpstart_cache_handles_boto3_client_errors():
399
420
400
421
def test_jumpstart_cache_accepts_input_parameters ():
401
422
402
- region = "us-east-1"
403
423
max_s3_cache_items = 1
404
424
s3_cache_expiration_horizon = datetime .timedelta (weeks = 2 )
405
425
max_semantic_version_cache_items = 3
@@ -408,7 +428,7 @@ def test_jumpstart_cache_accepts_input_parameters():
408
428
manifest_file_key = "some_s3_key"
409
429
410
430
cache = JumpStartModelsCache (
411
- region = region ,
431
+ region = REGION ,
412
432
max_s3_cache_items = max_s3_cache_items ,
413
433
s3_cache_expiration_horizon = s3_cache_expiration_horizon ,
414
434
max_semantic_version_cache_items = max_semantic_version_cache_items ,
@@ -418,7 +438,7 @@ def test_jumpstart_cache_accepts_input_parameters():
418
438
)
419
439
420
440
assert cache .get_manifest_file_s3_key () == manifest_file_key
421
- assert cache .get_region () == region
441
+ assert cache .get_region () == REGION
422
442
assert cache .get_bucket () == bucket
423
443
assert cache ._content_cache ._max_cache_items == max_s3_cache_items
424
444
assert cache ._content_cache ._expiration_horizon == s3_cache_expiration_horizon
@@ -741,7 +761,7 @@ def test_jumpstart_cache_get_specs():
741
761
@patch ("sagemaker.jumpstart.cache.os.path.isdir" )
742
762
@patch ("builtins.open" )
743
763
def test_jumpstart_local_metadata_override_header (
744
- mocked_open : Mock , mocked_is_dir : Mock , mocked_get_json_file_and_etag_from_s3 : Mock
764
+ mocked_open : Mock , mocked_is_dir : Mock , mocked_get_json_file_and_etag_from_s3 : Mock , sagemaker_session : Mock
745
765
):
746
766
mocked_open .side_effect = mock_open (read_data = json .dumps (BASE_MANIFEST ))
747
767
mocked_is_dir .return_value = True
@@ -760,7 +780,7 @@ def test_jumpstart_local_metadata_override_header(
760
780
mocked_is_dir .assert_any_call ("/some/directory/metadata/manifest/root" )
761
781
mocked_is_dir .assert_any_call ("/some/directory/metadata/specs/root" )
762
782
assert mocked_is_dir .call_count == 2
763
- mocked_open .assert_called_once_with (
783
+ mocked_open .assert_called_with (
764
784
"/some/directory/metadata/manifest/root/models_manifest.json" , "r"
765
785
)
766
786
mocked_get_json_file_and_etag_from_s3 .assert_not_called ()
@@ -783,6 +803,7 @@ def test_jumpstart_local_metadata_override_specs(
783
803
mocked_is_dir : Mock ,
784
804
mocked_get_json_file_and_etag_from_s3 : Mock ,
785
805
mock_emit_logs_based_on_model_specs ,
806
+ sagemaker_session ,
786
807
):
787
808
788
809
mocked_open .side_effect = [
@@ -791,7 +812,7 @@ def test_jumpstart_local_metadata_override_specs(
791
812
]
792
813
793
814
mocked_is_dir .return_value = True
794
- cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
815
+ cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" , s3_client = Mock (), sagemaker_session = sagemaker_session )
795
816
796
817
model_id , version = "tensorflow-ic-imagenet-inception-v3-classification-4" , "2.0.0"
797
818
assert JumpStartModelSpecs (BASE_SPEC ) == cache .get_specs (
@@ -845,7 +866,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
845
866
846
867
mocked_is_dir .assert_any_call ("/some/directory/metadata/manifest/root" )
847
868
assert mocked_is_dir .call_count == 2
848
- mocked_open .assert_not_called ()
869
+ assert mocked_open .call_count == 2
849
870
mocked_get_json_file_and_etag_from_s3 .assert_has_calls (
850
871
calls = [
851
872
call ("models_manifest.json" ),
0 commit comments