Skip to content

Commit 9fef8e0

Browse files
committed
Update skipped flaky tests
1 parent 7c49f5d commit 9fef8e0

File tree

3 files changed

+35
-30
lines changed

3 files changed

+35
-30
lines changed

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,12 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
329329
return sorted(list(model_id_version_dict.keys()))
330330

331331
if not list_old_models:
332-
model_id_version_dict = {
333-
model_id: set([max(versions)]) for model_id, versions in model_id_version_dict.items()
334-
}
332+
for model_id, versions in model_id_version_dict.items():
333+
try:
334+
model_id_version_dict.update({model_id: set([max(versions)])})
335+
except TypeError:
336+
versions = [str(v) for v in versions]
337+
model_id_version_dict.update({model_id: set([max(versions)])})
335338

336339
model_id_version_set: Set[Tuple[str, str]] = set()
337340
for model_id in model_id_version_dict:

src/sagemaker/jumpstart/payload_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sagemaker.jumpstart.constants import (
2424
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2525
)
26-
from sagemaker.jumpstart.enums import MIMEType
26+
from sagemaker.jumpstart.enums import JumpStartModelType, MIMEType
2727
from sagemaker.jumpstart.types import JumpStartSerializablePayload
2828
from sagemaker.jumpstart.utils import (
2929
get_jumpstart_content_bucket,
@@ -61,6 +61,7 @@ def _construct_payload(
6161
tolerate_vulnerable_model: bool = False,
6262
tolerate_deprecated_model: bool = False,
6363
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
64+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
6465
) -> Optional[JumpStartSerializablePayload]:
6566
"""Returns example payload from prompt.
6667
@@ -83,6 +84,8 @@ def _construct_payload(
8384
object, used for SageMaker interactions. If not
8485
specified, one is created using the default AWS configuration
8586
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
87+
model_type (JumpStartModelType): The type of the model, can be open weights model or
88+
proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
8689
Returns:
8790
Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
8891
this feature is unavailable for the specified model.
@@ -94,6 +97,7 @@ def _construct_payload(
9497
tolerate_vulnerable_model=tolerate_vulnerable_model,
9598
tolerate_deprecated_model=tolerate_deprecated_model,
9699
sagemaker_session=sagemaker_session,
100+
model_type=model_type,
97101
)
98102
if payloads is None or len(payloads) == 0:
99103
return None

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,6 @@ def test_list_jumpstart_models_simple_case(
227227
patched_get_manifest.assert_called()
228228
patched_get_model_specs.assert_not_called()
229229

230-
@pytest.mark.skipif(
231-
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
232-
reason="Contact JumpStart team to fix flaky test.",
233-
)
234230
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
235231
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
236232
def test_list_jumpstart_models_script_filter(
@@ -246,23 +242,25 @@ def test_list_jumpstart_models_script_filter(
246242
manifest_length = len(get_prototype_manifest())
247243
vals = [True, False]
248244
for val in vals:
249-
kwargs = {"filter": f"training_supported == {val}"}
245+
kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")}
250246
list_jumpstart_models(**kwargs)
251247
assert patched_read_s3_file.call_count == manifest_length
252-
patched_get_manifest.assert_called_once()
248+
assert patched_get_manifest.call_count == 2
253249

254250
patched_get_manifest.reset_mock()
255251
patched_read_s3_file.reset_mock()
256252

257-
kwargs = {"filter": f"training_supported != {val}"}
253+
kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")}
258254
list_jumpstart_models(**kwargs)
259255
assert patched_read_s3_file.call_count == manifest_length
260256
assert patched_get_manifest.call_count == 2
261257

262258
patched_get_manifest.reset_mock()
263259
patched_read_s3_file.reset_mock()
264-
265-
kwargs = {"filter": f"training_supported in {vals}", "list_versions": True}
260+
kwargs = {
261+
"filter": And(f"training_supported != {val}", "model_type is open_weights"),
262+
"list_versions": True,
263+
}
266264
assert list_jumpstart_models(**kwargs) == [
267265
("catboost-classification-model", "1.0.0"),
268266
("huggingface-spc-bert-base-cased", "1.0.0"),
@@ -279,7 +277,7 @@ def test_list_jumpstart_models_script_filter(
279277
patched_get_manifest.reset_mock()
280278
patched_read_s3_file.reset_mock()
281279

282-
kwargs = {"filter": f"training_supported not in {vals}"}
280+
kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")}
283281
models = list_jumpstart_models(**kwargs)
284282
assert [] == models
285283
assert patched_read_s3_file.call_count == manifest_length
@@ -518,10 +516,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
518516
list_old_models=False, list_versions=True
519517
) == list_jumpstart_models(list_versions=True)
520518

521-
@pytest.mark.skipif(
522-
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
523-
reason="Contact JumpStart team to fix flaky test.",
524-
)
525519
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
526520
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
527521
def test_list_jumpstart_models_vulnerable_models(
@@ -547,12 +541,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
547541
patched_read_s3_file.side_effect = vulnerable_inference_model_spec
548542

549543
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
550-
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
551544
assert [] == list_jumpstart_models(
552-
And("inference_vulnerable is false", "training_vulnerable is false")
545+
And(
546+
"inference_vulnerable is false",
547+
"training_vulnerable is false",
548+
"model_type is open_weights",
549+
)
553550
)
554551

555-
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
552+
assert patched_read_s3_file.call_count == num_specs
556553
assert patched_get_manifest.call_count == 2
557554

558555
patched_get_manifest.reset_mock()
@@ -561,10 +558,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
561558
patched_read_s3_file.side_effect = vulnerable_training_model_spec
562559

563560
assert [] == list_jumpstart_models(
564-
And("inference_vulnerable is false", "training_vulnerable is false")
561+
And(
562+
"inference_vulnerable is false",
563+
"training_vulnerable is false",
564+
"model_type is open_weights",
565+
)
565566
)
566567

567-
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
568+
assert patched_read_s3_file.call_count == num_specs
568569
assert patched_get_manifest.call_count == 2
569570

570571
patched_get_manifest.reset_mock()
@@ -574,10 +575,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
574575

575576
assert patched_read_s3_file.call_count == 0
576577

577-
@pytest.mark.skipif(
578-
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
579-
reason="Contact JumpStart team to fix flaky test.",
580-
)
581578
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
582579
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
583580
def test_list_jumpstart_models_deprecated_models(
@@ -598,10 +595,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
598595
patched_read_s3_file.side_effect = deprecated_model_spec
599596

600597
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
601-
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
602-
assert [] == list_jumpstart_models("deprecated equals false")
598+
assert [] == list_jumpstart_models(
599+
And("deprecated equals false", "model_type is open_weights")
600+
)
603601

604-
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
602+
assert patched_read_s3_file.call_count == num_specs
605603
assert patched_get_manifest.call_count == 2
606604

607605
patched_get_manifest.reset_mock()

0 commit comments

Comments
 (0)