3
3
4
4
from unittest import TestCase
5
5
from unittest .mock import Mock , patch
6
- import datetime
7
6
8
7
import pytest
9
8
from sagemaker .jumpstart .constants import (
17
16
get_prototype_manifest ,
18
17
get_prototype_model_spec ,
19
18
)
20
- from tests .unit .sagemaker .jumpstart .constants import BASE_PROPRIETARY_MANIFEST
21
19
from sagemaker .jumpstart .enums import JumpStartModelType
22
20
from sagemaker .jumpstart .notebook_utils import (
23
21
_generate_jumpstart_model_versions ,
@@ -227,10 +225,6 @@ def test_list_jumpstart_models_simple_case(
227
225
patched_get_manifest .assert_called ()
228
226
patched_get_model_specs .assert_not_called ()
229
227
230
- @pytest .mark .skipif (
231
- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
232
- reason = "Contact JumpStart team to fix flaky test." ,
233
- )
234
228
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
235
229
@patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
236
230
def test_list_jumpstart_models_script_filter (
@@ -246,23 +240,25 @@ def test_list_jumpstart_models_script_filter(
246
240
manifest_length = len (get_prototype_manifest ())
247
241
vals = [True , False ]
248
242
for val in vals :
249
- kwargs = {"filter" : f"training_supported == { val } " }
243
+ kwargs = {"filter" : And ( f"training_supported == { val } " , "model_type is open_weights" ) }
250
244
list_jumpstart_models (** kwargs )
251
245
assert patched_read_s3_file .call_count == manifest_length
252
- patched_get_manifest .assert_called_once ()
246
+ assert patched_get_manifest .call_count == 2
253
247
254
248
patched_get_manifest .reset_mock ()
255
249
patched_read_s3_file .reset_mock ()
256
250
257
- kwargs = {"filter" : f"training_supported != { val } " }
251
+ kwargs = {"filter" : And ( f"training_supported != { val } " , "model_type is open_weights" ) }
258
252
list_jumpstart_models (** kwargs )
259
253
assert patched_read_s3_file .call_count == manifest_length
260
254
assert patched_get_manifest .call_count == 2
261
255
262
256
patched_get_manifest .reset_mock ()
263
257
patched_read_s3_file .reset_mock ()
264
-
265
- kwargs = {"filter" : f"training_supported in { vals } " , "list_versions" : True }
258
+ kwargs = {
259
+ "filter" : And (f"training_supported != { val } " , "model_type is open_weights" ),
260
+ "list_versions" : True ,
261
+ }
266
262
assert list_jumpstart_models (** kwargs ) == [
267
263
("catboost-classification-model" , "1.0.0" ),
268
264
("huggingface-spc-bert-base-cased" , "1.0.0" ),
@@ -279,7 +275,7 @@ def test_list_jumpstart_models_script_filter(
279
275
patched_get_manifest .reset_mock ()
280
276
patched_read_s3_file .reset_mock ()
281
277
282
- kwargs = {"filter" : f"training_supported not in { vals } " }
278
+ kwargs = {"filter" : And ( f"training_supported not in { vals } " , "model_type is open_weights" ) }
283
279
models = list_jumpstart_models (** kwargs )
284
280
assert [] == models
285
281
assert patched_read_s3_file .call_count == manifest_length
@@ -518,10 +514,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
518
514
list_old_models = False , list_versions = True
519
515
) == list_jumpstart_models (list_versions = True )
520
516
521
- @pytest .mark .skipif (
522
- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
523
- reason = "Contact JumpStart team to fix flaky test." ,
524
- )
525
517
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
526
518
@patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
527
519
def test_list_jumpstart_models_vulnerable_models (
@@ -547,12 +539,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
547
539
patched_read_s3_file .side_effect = vulnerable_inference_model_spec
548
540
549
541
num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
550
- num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
551
542
assert [] == list_jumpstart_models (
552
- And ("inference_vulnerable is false" , "training_vulnerable is false" )
543
+ And (
544
+ "inference_vulnerable is false" ,
545
+ "training_vulnerable is false" ,
546
+ "model_type is open_weights" ,
547
+ )
553
548
)
554
549
555
- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
550
+ assert patched_read_s3_file .call_count == num_specs
556
551
assert patched_get_manifest .call_count == 2
557
552
558
553
patched_get_manifest .reset_mock ()
@@ -561,10 +556,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
561
556
patched_read_s3_file .side_effect = vulnerable_training_model_spec
562
557
563
558
assert [] == list_jumpstart_models (
564
- And ("inference_vulnerable is false" , "training_vulnerable is false" )
559
+ And (
560
+ "inference_vulnerable is false" ,
561
+ "training_vulnerable is false" ,
562
+ "model_type is open_weights" ,
563
+ )
565
564
)
566
565
567
- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
566
+ assert patched_read_s3_file .call_count == num_specs
568
567
assert patched_get_manifest .call_count == 2
569
568
570
569
patched_get_manifest .reset_mock ()
@@ -574,10 +573,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
574
573
575
574
assert patched_read_s3_file .call_count == 0
576
575
577
- @pytest .mark .skipif (
578
- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
579
- reason = "Contact JumpStart team to fix flaky test." ,
580
- )
581
576
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
582
577
@patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
583
578
def test_list_jumpstart_models_deprecated_models (
@@ -598,10 +593,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
598
593
patched_read_s3_file .side_effect = deprecated_model_spec
599
594
600
595
num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
601
- num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
602
- assert [] == list_jumpstart_models ("deprecated equals false" )
596
+ assert [] == list_jumpstart_models (
597
+ And ("deprecated equals false" , "model_type is open_weights" )
598
+ )
603
599
604
- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
600
+ assert patched_read_s3_file .call_count == num_specs
605
601
assert patched_get_manifest .call_count == 2
606
602
607
603
patched_get_manifest .reset_mock ()
0 commit comments