53
53
@pytest .fixture ()
54
54
def sagemaker_session ():
55
55
mocked_boto_session = Mock (name = "boto_session" )
56
- mocked_s3_client = Mock (name = "s3_client" )
56
+ mocked_s3_client = Mock (name = "s3_client" )
57
57
mocked_sagemaker_session = Mock (
58
- name = "sagemaker_session" , boto_session = mocked_boto_session , s3_client = mocked_s3_client , boto_region_name = REGION , config = None ,
58
+ name = "sagemaker_session" ,
59
+ boto_session = mocked_boto_session ,
60
+ s3_client = mocked_s3_client ,
61
+ boto_region_name = REGION ,
62
+ config = None ,
59
63
)
60
64
mocked_sagemaker_session .sagemaker_config = {}
61
65
mocked_sagemaker_session ._client_config .user_agent = (
@@ -65,7 +69,6 @@ def sagemaker_session():
65
69
return mocked_sagemaker_session
66
70
67
71
68
-
69
72
@patch .object (JumpStartModelsCache , "_retrieval_function" , patched_retrieval_function )
70
73
@patch ("sagemaker.jumpstart.utils.get_sagemaker_version" , lambda : "2.68.3" )
71
74
def test_jumpstart_cache_get_header ():
@@ -761,7 +764,10 @@ def test_jumpstart_cache_get_specs():
761
764
@patch ("sagemaker.jumpstart.cache.os.path.isdir" )
762
765
@patch ("builtins.open" )
763
766
def test_jumpstart_local_metadata_override_header (
764
- mocked_open : Mock , mocked_is_dir : Mock , mocked_get_json_file_and_etag_from_s3 : Mock , sagemaker_session : Mock
767
+ mocked_open : Mock ,
768
+ mocked_is_dir : Mock ,
769
+ mocked_get_json_file_and_etag_from_s3 : Mock ,
770
+ sagemaker_session : Mock ,
765
771
):
766
772
mocked_open .side_effect = mock_open (read_data = json .dumps (BASE_MANIFEST ))
767
773
mocked_is_dir .return_value = True
@@ -812,7 +818,9 @@ def test_jumpstart_local_metadata_override_specs(
812
818
]
813
819
814
820
mocked_is_dir .return_value = True
815
- cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" , s3_client = Mock (), sagemaker_session = sagemaker_session )
821
+ cache = JumpStartModelsCache (
822
+ s3_bucket_name = "some_bucket" , s3_client = Mock (), sagemaker_session = sagemaker_session
823
+ )
816
824
817
825
model_id , version = "tensorflow-ic-imagenet-inception-v3-classification-4" , "2.0.0"
818
826
assert JumpStartModelSpecs (BASE_SPEC ) == cache .get_specs (
0 commit comments