12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
from copy import deepcopy
15
+ import datetime
15
16
from unittest import mock
16
17
from unittest .mock import patch
17
18
import pytest
18
19
from mock import Mock
19
20
from sagemaker .jumpstart .curated_hub .curated_hub import CuratedHub
20
- from sagemaker .jumpstart .curated_hub .types import JumpStartModelInfo
21
+ from sagemaker .jumpstart .curated_hub .types import JumpStartModelInfo , S3ObjectLocation
21
22
from sagemaker .jumpstart .types import JumpStartModelSpecs
22
23
from tests .unit .sagemaker .jumpstart .constants import BASE_SPEC
23
24
from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
24
25
26
+
25
27
REGION = "us-east-1"
26
28
ACCOUNT_ID = "123456789123"
27
29
HUB_NAME = "mock-hub-name"
28
30
29
31
MODULE_PATH = "sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub"
30
32
33
+ FAKE_TIME = datetime .datetime (1997 , 8 , 14 , 00 , 00 , 00 )
31
34
32
35
@pytest .fixture ()
33
36
def sagemaker_session ():
@@ -39,7 +42,7 @@ def sagemaker_session():
39
42
"Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource"
40
43
)
41
44
sagemaker_session_mock .describe_hub .return_value = {
42
- "S3StorageConfig" : {"S3OutputPath" : "mock-bucket-123" }
45
+ "S3StorageConfig" : {"S3OutputPath" : "s3:// mock-bucket-123" }
43
46
}
44
47
sagemaker_session_mock .account_id .return_value = ACCOUNT_ID
45
48
return sagemaker_session_mock
@@ -66,7 +69,9 @@ def test_instantiates(sagemaker_session):
66
69
),
67
70
],
68
71
)
72
+ @patch ("sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub._generate_hub_storage_location" )
69
73
def test_create_with_no_bucket_name (
74
+ mock_generate_hub_storage_location ,
70
75
sagemaker_session ,
71
76
hub_name ,
72
77
hub_description ,
@@ -75,18 +80,22 @@ def test_create_with_no_bucket_name(
75
80
hub_search_keywords ,
76
81
tags ,
77
82
):
83
+ storage_location = S3ObjectLocation ("sagemaker-hubs-us-east-1-123456789123" ,f"{ hub_name } -{ FAKE_TIME .timestamp ()} " )
84
+ mock_generate_hub_storage_location .return_value = storage_location
78
85
create_hub = {"HubArn" : f"arn:aws:sagemaker:us-east-1:123456789123:hub/{ hub_name } " }
79
86
sagemaker_session .create_hub = Mock (return_value = create_hub )
80
87
sagemaker_session .describe_hub .return_value = {
81
- "S3StorageConfig" : {"S3OutputPath" : hub_bucket_name }
88
+ "S3StorageConfig" : {"S3OutputPath" : f"s3:// { hub_bucket_name } / { storage_location . key } " }
82
89
}
83
90
hub = CuratedHub (hub_name = hub_name , sagemaker_session = sagemaker_session )
84
91
request = {
85
92
"hub_name" : hub_name ,
86
93
"hub_description" : hub_description ,
87
94
"hub_display_name" : hub_display_name ,
88
95
"hub_search_keywords" : hub_search_keywords ,
89
- "s3_storage_config" : {"S3OutputPath" : "s3://sagemaker-hubs-us-east-1-123456789123" },
96
+ "s3_storage_config" : {
97
+ "S3OutputPath" : f"s3://sagemaker-hubs-us-east-1-123456789123/{ storage_location .key } "
98
+ },
90
99
"tags" : tags ,
91
100
}
92
101
response = hub .create (
@@ -113,7 +122,9 @@ def test_create_with_no_bucket_name(
113
122
),
114
123
],
115
124
)
125
+ @patch ("sagemaker.jumpstart.curated_hub.curated_hub.CuratedHub._generate_hub_storage_location" )
116
126
def test_create_with_bucket_name (
127
+ mock_generate_hub_storage_location ,
117
128
sagemaker_session ,
118
129
hub_name ,
119
130
hub_description ,
@@ -122,6 +133,8 @@ def test_create_with_bucket_name(
122
133
hub_search_keywords ,
123
134
tags ,
124
135
):
136
+ storage_location = S3ObjectLocation (hub_bucket_name ,f"{ hub_name } -{ FAKE_TIME .timestamp ()} " )
137
+ mock_generate_hub_storage_location .return_value = storage_location
125
138
create_hub = {"HubArn" : f"arn:aws:sagemaker:us-east-1:123456789123:hub/{ hub_name } " }
126
139
sagemaker_session .create_hub = Mock (return_value = create_hub )
127
140
hub = CuratedHub (
@@ -132,7 +145,7 @@ def test_create_with_bucket_name(
132
145
"hub_description" : hub_description ,
133
146
"hub_display_name" : hub_display_name ,
134
147
"hub_search_keywords" : hub_search_keywords ,
135
- "s3_storage_config" : {"S3OutputPath" : "s3://mock-bucket-123" },
148
+ "s3_storage_config" : {"S3OutputPath" : f "s3://mock-bucket-123/ { storage_location . key } " },
136
149
"tags" : tags ,
137
150
}
138
151
response = hub .create (
@@ -397,92 +410,92 @@ def test_determine_models_to_sync(sagemaker_session):
397
410
hub_name = "mock_hub_name"
398
411
hub = CuratedHub (hub_name = hub_name , sagemaker_session = sagemaker_session )
399
412
400
- js_models_in_hub = [
401
- {
413
+ js_model_map = {
414
+ "mock-model-two-pytorch" : {
402
415
"name" : "mock-model-two-pytorch" ,
403
416
"version" : "1.0.1" ,
404
417
"search_keywords" : [
405
418
"@jumpstart-model-id:model-two-pytorch" ,
406
419
"@jumpstart-model-version:1.0.2" ,
407
420
],
408
421
},
409
- {
422
+ "mock-model-four-huggingface" : {
410
423
"name" : "mock-model-four-huggingface" ,
411
424
"version" : "2.0.2" ,
412
425
"search_keywords" : [
413
426
"@jumpstart-model-id:model-four-huggingface" ,
414
427
"@jumpstart-model-version:2.0.2" ,
415
428
],
416
429
},
417
- ]
430
+ }
418
431
model_one = JumpStartModelInfo ("mock-model-one-huggingface" , "1.2.3" )
419
432
model_two = JumpStartModelInfo ("mock-model-two-pytorch" , "1.0.2" )
420
433
# No model_one, older model_two
421
- res = hub ._determine_models_to_sync ([model_one , model_two ], js_models_in_hub )
434
+ res = hub ._determine_models_to_sync ([model_one , model_two ], js_model_map )
422
435
assert res == [model_one , model_two ]
423
436
424
- js_models_in_hub = [
425
- {
437
+ js_model_map = {
438
+ "mock-model-two-pytorch" : {
426
439
"name" : "mock-model-two-pytorch" ,
427
440
"version" : "1.0.3" ,
428
441
"search_keywords" : [
429
442
"@jumpstart-model-id:model-two-pytorch" ,
430
443
"@jumpstart-model-version:1.0.3" ,
431
444
],
432
445
},
433
- {
446
+ "mock-model-four-huggingface" : {
434
447
"name" : "mock-model-four-huggingface" ,
435
448
"version" : "2.0.2" ,
436
449
"search_keywords" : [
437
450
"@jumpstart-model-id:model-four-huggingface" ,
438
451
"@jumpstart-model-version:2.0.2" ,
439
452
],
440
453
},
441
- ]
454
+ }
442
455
# No model_one, newer model_two
443
- res = hub ._determine_models_to_sync ([model_one , model_two ], js_models_in_hub )
456
+ res = hub ._determine_models_to_sync ([model_one , model_two ], js_model_map )
444
457
assert res == [model_one ]
445
458
446
- js_models_in_hub = [
447
- {
459
+ js_model_map = {
460
+ "mock-model-one-huggingface" : {
448
461
"name" : "mock-model-one-huggingface" ,
449
462
"version" : "1.2.3" ,
450
463
"search_keywords" : [
451
464
"@jumpstart-model-id:model-one-huggingface" ,
452
465
"@jumpstart-model-version:1.2.3" ,
453
466
],
454
467
},
455
- {
468
+ "mock-model-two-pytorch" : {
456
469
"name" : "mock-model-two-pytorch" ,
457
470
"version" : "1.0.2" ,
458
471
"search_keywords" : [
459
472
"@jumpstart-model-id:model-two-pytorch" ,
460
473
"@jumpstart-model-version:1.0.2" ,
461
474
],
462
475
},
463
- ]
476
+ }
464
477
# Same model_one, same model_two
465
- res = hub ._determine_models_to_sync ([model_one , model_two ], js_models_in_hub )
478
+ res = hub ._determine_models_to_sync ([model_one , model_two ], js_model_map )
466
479
assert res == []
467
480
468
- js_models_in_hub = [
469
- {
481
+ js_model_map = {
482
+ "mock-model-one-huggingface" : {
470
483
"name" : "mock-model-one-huggingface" ,
471
484
"version" : "1.2.1" ,
472
485
"search_keywords" : [
473
486
"@jumpstart-model-id:model-one-huggingface" ,
474
487
"@jumpstart-model-version:1.2.1" ,
475
488
],
476
489
},
477
- {
490
+ "mock-model-two-pytorch" : {
478
491
"name" : "mock-model-two-pytorch" ,
479
492
"version" : "1.0.2" ,
480
493
"search_keywords" : [
481
494
"@jumpstart-model-id:model-two-pytorch" ,
482
495
"@jumpstart-model-version:1.0.2" ,
483
496
],
484
497
},
485
- ]
498
+ }
486
499
# Old model_one, same model_two
487
- res = hub ._determine_models_to_sync ([model_one , model_two ], js_models_in_hub )
500
+ res = hub ._determine_models_to_sync ([model_one , model_two ], js_model_map )
488
501
assert res == [model_one ]
0 commit comments