Skip to content

Commit 593532b

Browse files
Captainiaroot
authored and
root
committed
chore: update skipped flaky tests (aws#4644)
* Update skipped flaky tests * flake8 * format * format
1 parent 837baa8 commit 593532b

File tree

3 files changed

+35
-32
lines changed

3 files changed

+35
-32
lines changed

src/sagemaker/jumpstart/notebook_utils.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,12 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
329329
return sorted(list(model_id_version_dict.keys()))
330330

331331
if not list_old_models:
332-
model_id_version_dict = {
333-
model_id: set([max(versions)]) for model_id, versions in model_id_version_dict.items()
334-
}
332+
for model_id, versions in model_id_version_dict.items():
333+
try:
334+
model_id_version_dict.update({model_id: set([max(versions)])})
335+
except TypeError:
336+
versions = [str(v) for v in versions]
337+
model_id_version_dict.update({model_id: set([max(versions)])})
335338

336339
model_id_version_set: Set[Tuple[str, str]] = set()
337340
for model_id in model_id_version_dict:

src/sagemaker/jumpstart/payload_utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sagemaker.jumpstart.constants import (
2424
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2525
)
26-
from sagemaker.jumpstart.enums import MIMEType
26+
from sagemaker.jumpstart.enums import JumpStartModelType, MIMEType
2727
from sagemaker.jumpstart.types import JumpStartSerializablePayload
2828
from sagemaker.jumpstart.utils import (
2929
get_jumpstart_content_bucket,
@@ -61,6 +61,7 @@ def _construct_payload(
6161
tolerate_vulnerable_model: bool = False,
6262
tolerate_deprecated_model: bool = False,
6363
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
64+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
6465
) -> Optional[JumpStartSerializablePayload]:
6566
"""Returns example payload from prompt.
6667
@@ -83,6 +84,8 @@ def _construct_payload(
8384
object, used for SageMaker interactions. If not
8485
specified, one is created using the default AWS configuration
8586
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
87+
model_type (JumpStartModelType): The type of the model, can be open weights model or
88+
proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
8689
Returns:
8790
Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
8891
this feature is unavailable for the specified model.
@@ -94,6 +97,7 @@ def _construct_payload(
9497
tolerate_vulnerable_model=tolerate_vulnerable_model,
9598
tolerate_deprecated_model=tolerate_deprecated_model,
9699
sagemaker_session=sagemaker_session,
100+
model_type=model_type,
97101
)
98102
if payloads is None or len(payloads) == 0:
99103
return None

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

+24-28
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from unittest import TestCase
55
from unittest.mock import Mock, patch
6-
import datetime
76

87
import pytest
98
from sagemaker.jumpstart.constants import (
@@ -17,7 +16,6 @@
1716
get_prototype_manifest,
1817
get_prototype_model_spec,
1918
)
20-
from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST
2119
from sagemaker.jumpstart.enums import JumpStartModelType
2220
from sagemaker.jumpstart.notebook_utils import (
2321
_generate_jumpstart_model_versions,
@@ -227,10 +225,6 @@ def test_list_jumpstart_models_simple_case(
227225
patched_get_manifest.assert_called()
228226
patched_get_model_specs.assert_not_called()
229227

230-
@pytest.mark.skipif(
231-
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
232-
reason="Contact JumpStart team to fix flaky test.",
233-
)
234228
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
235229
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
236230
def test_list_jumpstart_models_script_filter(
@@ -246,23 +240,25 @@ def test_list_jumpstart_models_script_filter(
246240
manifest_length = len(get_prototype_manifest())
247241
vals = [True, False]
248242
for val in vals:
249-
kwargs = {"filter": f"training_supported == {val}"}
243+
kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")}
250244
list_jumpstart_models(**kwargs)
251245
assert patched_read_s3_file.call_count == manifest_length
252-
patched_get_manifest.assert_called_once()
246+
assert patched_get_manifest.call_count == 2
253247

254248
patched_get_manifest.reset_mock()
255249
patched_read_s3_file.reset_mock()
256250

257-
kwargs = {"filter": f"training_supported != {val}"}
251+
kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")}
258252
list_jumpstart_models(**kwargs)
259253
assert patched_read_s3_file.call_count == manifest_length
260254
assert patched_get_manifest.call_count == 2
261255

262256
patched_get_manifest.reset_mock()
263257
patched_read_s3_file.reset_mock()
264-
265-
kwargs = {"filter": f"training_supported in {vals}", "list_versions": True}
258+
kwargs = {
259+
"filter": And(f"training_supported != {val}", "model_type is open_weights"),
260+
"list_versions": True,
261+
}
266262
assert list_jumpstart_models(**kwargs) == [
267263
("catboost-classification-model", "1.0.0"),
268264
("huggingface-spc-bert-base-cased", "1.0.0"),
@@ -279,7 +275,7 @@ def test_list_jumpstart_models_script_filter(
279275
patched_get_manifest.reset_mock()
280276
patched_read_s3_file.reset_mock()
281277

282-
kwargs = {"filter": f"training_supported not in {vals}"}
278+
kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")}
283279
models = list_jumpstart_models(**kwargs)
284280
assert [] == models
285281
assert patched_read_s3_file.call_count == manifest_length
@@ -518,10 +514,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
518514
list_old_models=False, list_versions=True
519515
) == list_jumpstart_models(list_versions=True)
520516

521-
@pytest.mark.skipif(
522-
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
523-
reason="Contact JumpStart team to fix flaky test.",
524-
)
525517
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
526518
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
527519
def test_list_jumpstart_models_vulnerable_models(
@@ -547,12 +539,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
547539
patched_read_s3_file.side_effect = vulnerable_inference_model_spec
548540

549541
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
550-
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
551542
assert [] == list_jumpstart_models(
552-
And("inference_vulnerable is false", "training_vulnerable is false")
543+
And(
544+
"inference_vulnerable is false",
545+
"training_vulnerable is false",
546+
"model_type is open_weights",
547+
)
553548
)
554549

555-
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
550+
assert patched_read_s3_file.call_count == num_specs
556551
assert patched_get_manifest.call_count == 2
557552

558553
patched_get_manifest.reset_mock()
@@ -561,10 +556,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
561556
patched_read_s3_file.side_effect = vulnerable_training_model_spec
562557

563558
assert [] == list_jumpstart_models(
564-
And("inference_vulnerable is false", "training_vulnerable is false")
559+
And(
560+
"inference_vulnerable is false",
561+
"training_vulnerable is false",
562+
"model_type is open_weights",
563+
)
565564
)
566565

567-
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
566+
assert patched_read_s3_file.call_count == num_specs
568567
assert patched_get_manifest.call_count == 2
569568

570569
patched_get_manifest.reset_mock()
@@ -574,10 +573,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
574573

575574
assert patched_read_s3_file.call_count == 0
576575

577-
@pytest.mark.skipif(
578-
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
579-
reason="Contact JumpStart team to fix flaky test.",
580-
)
581576
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
582577
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
583578
def test_list_jumpstart_models_deprecated_models(
@@ -598,10 +593,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
598593
patched_read_s3_file.side_effect = deprecated_model_spec
599594

600595
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
601-
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
602-
assert [] == list_jumpstart_models("deprecated equals false")
596+
assert [] == list_jumpstart_models(
597+
And("deprecated equals false", "model_type is open_weights")
598+
)
603599

604-
assert patched_read_s3_file.call_count == num_specs + num_prop_specs
600+
assert patched_read_s3_file.call_count == num_specs
605601
assert patched_get_manifest.call_count == 2
606602

607603
patched_get_manifest.reset_mock()

0 commit comments

Comments
 (0)