@@ -227,10 +227,6 @@ def test_list_jumpstart_models_simple_case(
227
227
patched_get_manifest .assert_called ()
228
228
patched_get_model_specs .assert_not_called ()
229
229
230
- @pytest .mark .skipif (
231
- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
232
- reason = "Contact JumpStart team to fix flaky test." ,
233
- )
234
230
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
235
231
@patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
236
232
def test_list_jumpstart_models_script_filter (
@@ -246,23 +242,25 @@ def test_list_jumpstart_models_script_filter(
246
242
manifest_length = len (get_prototype_manifest ())
247
243
vals = [True , False ]
248
244
for val in vals :
249
- kwargs = {"filter" : f"training_supported == { val } " }
245
+ kwargs = {"filter" : And ( f"training_supported == { val } " , "model_type is open_weights" ) }
250
246
list_jumpstart_models (** kwargs )
251
247
assert patched_read_s3_file .call_count == manifest_length
252
- patched_get_manifest .assert_called_once ()
248
+ assert patched_get_manifest .call_count == 2
253
249
254
250
patched_get_manifest .reset_mock ()
255
251
patched_read_s3_file .reset_mock ()
256
252
257
- kwargs = {"filter" : f"training_supported != { val } " }
253
+ kwargs = {"filter" : And ( f"training_supported != { val } " , "model_type is open_weights" ) }
258
254
list_jumpstart_models (** kwargs )
259
255
assert patched_read_s3_file .call_count == manifest_length
260
256
assert patched_get_manifest .call_count == 2
261
257
262
258
patched_get_manifest .reset_mock ()
263
259
patched_read_s3_file .reset_mock ()
264
-
265
- kwargs = {"filter" : f"training_supported in { vals } " , "list_versions" : True }
260
+ kwargs = {
261
+ "filter" : And (f"training_supported != { val } " , "model_type is open_weights" ),
262
+ "list_versions" : True ,
263
+ }
266
264
assert list_jumpstart_models (** kwargs ) == [
267
265
("catboost-classification-model" , "1.0.0" ),
268
266
("huggingface-spc-bert-base-cased" , "1.0.0" ),
@@ -279,7 +277,7 @@ def test_list_jumpstart_models_script_filter(
279
277
patched_get_manifest .reset_mock ()
280
278
patched_read_s3_file .reset_mock ()
281
279
282
- kwargs = {"filter" : f"training_supported not in { vals } " }
280
+ kwargs = {"filter" : And ( f"training_supported not in { vals } " , "model_type is open_weights" ) }
283
281
models = list_jumpstart_models (** kwargs )
284
282
assert [] == models
285
283
assert patched_read_s3_file .call_count == manifest_length
@@ -518,10 +516,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
518
516
list_old_models = False , list_versions = True
519
517
) == list_jumpstart_models (list_versions = True )
520
518
521
- @pytest .mark .skipif (
522
- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
523
- reason = "Contact JumpStart team to fix flaky test." ,
524
- )
525
519
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
526
520
@patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
527
521
def test_list_jumpstart_models_vulnerable_models (
@@ -547,12 +541,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
547
541
patched_read_s3_file .side_effect = vulnerable_inference_model_spec
548
542
549
543
num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
550
- num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
551
544
assert [] == list_jumpstart_models (
552
- And ("inference_vulnerable is false" , "training_vulnerable is false" )
545
+ And (
546
+ "inference_vulnerable is false" ,
547
+ "training_vulnerable is false" ,
548
+ "model_type is open_weights" ,
549
+ )
553
550
)
554
551
555
- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
552
+ assert patched_read_s3_file .call_count == num_specs
556
553
assert patched_get_manifest .call_count == 2
557
554
558
555
patched_get_manifest .reset_mock ()
@@ -561,10 +558,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
561
558
patched_read_s3_file .side_effect = vulnerable_training_model_spec
562
559
563
560
assert [] == list_jumpstart_models (
564
- And ("inference_vulnerable is false" , "training_vulnerable is false" )
561
+ And (
562
+ "inference_vulnerable is false" ,
563
+ "training_vulnerable is false" ,
564
+ "model_type is open_weights" ,
565
+ )
565
566
)
566
567
567
- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
568
+ assert patched_read_s3_file .call_count == num_specs
568
569
assert patched_get_manifest .call_count == 2
569
570
570
571
patched_get_manifest .reset_mock ()
@@ -574,10 +575,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
574
575
575
576
assert patched_read_s3_file .call_count == 0
576
577
577
- @pytest .mark .skipif (
578
- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
579
- reason = "Contact JumpStart team to fix flaky test." ,
580
- )
581
578
@patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
582
579
@patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
583
580
def test_list_jumpstart_models_deprecated_models (
@@ -598,10 +595,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
598
595
patched_read_s3_file .side_effect = deprecated_model_spec
599
596
600
597
num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
601
- num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
602
- assert [] == list_jumpstart_models ("deprecated equals false" )
598
+ assert [] == list_jumpstart_models (
599
+ And ("deprecated equals false" , "model_type is open_weights" )
600
+ )
603
601
604
- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
602
+ assert patched_read_s3_file .call_count == num_specs
605
603
assert patched_get_manifest .call_count == 2
606
604
607
605
patched_get_manifest .reset_mock ()
0 commit comments