Skip to content

chore: update skipped flaky tests #4644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/sagemaker/jumpstart/notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,12 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
return sorted(list(model_id_version_dict.keys()))

if not list_old_models:
model_id_version_dict = {
model_id: set([max(versions)]) for model_id, versions in model_id_version_dict.items()
}
for model_id, versions in model_id_version_dict.items():
try:
model_id_version_dict.update({model_id: set([max(versions)])})
except TypeError:
versions = [str(v) for v in versions]
model_id_version_dict.update({model_id: set([max(versions)])})

model_id_version_set: Set[Tuple[str, str]] = set()
for model_id in model_id_version_dict:
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/jumpstart/payload_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
)
from sagemaker.jumpstart.enums import MIMEType
from sagemaker.jumpstart.enums import JumpStartModelType, MIMEType
from sagemaker.jumpstart.types import JumpStartSerializablePayload
from sagemaker.jumpstart.utils import (
get_jumpstart_content_bucket,
Expand Down Expand Up @@ -61,6 +61,7 @@ def _construct_payload(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
) -> Optional[JumpStartSerializablePayload]:
"""Returns example payload from prompt.

Expand All @@ -83,6 +84,8 @@ def _construct_payload(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
model_type (JumpStartModelType): The type of the model, can be open weights model or
proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS).
Returns:
Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if
this feature is unavailable for the specified model.
Expand All @@ -94,6 +97,7 @@ def _construct_payload(
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
model_type=model_type,
)
if payloads is None or len(payloads) == 0:
return None
Expand Down
52 changes: 24 additions & 28 deletions tests/unit/sagemaker/jumpstart/test_notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from unittest import TestCase
from unittest.mock import Mock, patch
import datetime

import pytest
from sagemaker.jumpstart.constants import (
Expand All @@ -17,7 +16,6 @@
get_prototype_manifest,
get_prototype_model_spec,
)
from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart.notebook_utils import (
_generate_jumpstart_model_versions,
Expand Down Expand Up @@ -227,10 +225,6 @@ def test_list_jumpstart_models_simple_case(
patched_get_manifest.assert_called()
patched_get_model_specs.assert_not_called()

@pytest.mark.skipif(
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
reason="Contact JumpStart team to fix flaky test.",
)
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
def test_list_jumpstart_models_script_filter(
Expand All @@ -246,23 +240,25 @@ def test_list_jumpstart_models_script_filter(
manifest_length = len(get_prototype_manifest())
vals = [True, False]
for val in vals:
kwargs = {"filter": f"training_supported == {val}"}
kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")}
list_jumpstart_models(**kwargs)
assert patched_read_s3_file.call_count == manifest_length
patched_get_manifest.assert_called_once()
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
patched_read_s3_file.reset_mock()

kwargs = {"filter": f"training_supported != {val}"}
kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")}
list_jumpstart_models(**kwargs)
assert patched_read_s3_file.call_count == manifest_length
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
patched_read_s3_file.reset_mock()

kwargs = {"filter": f"training_supported in {vals}", "list_versions": True}
kwargs = {
"filter": And(f"training_supported != {val}", "model_type is open_weights"),
"list_versions": True,
}
assert list_jumpstart_models(**kwargs) == [
("catboost-classification-model", "1.0.0"),
("huggingface-spc-bert-base-cased", "1.0.0"),
Expand All @@ -279,7 +275,7 @@ def test_list_jumpstart_models_script_filter(
patched_get_manifest.reset_mock()
patched_read_s3_file.reset_mock()

kwargs = {"filter": f"training_supported not in {vals}"}
kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")}
models = list_jumpstart_models(**kwargs)
assert [] == models
assert patched_read_s3_file.call_count == manifest_length
Expand Down Expand Up @@ -518,10 +514,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
list_old_models=False, list_versions=True
) == list_jumpstart_models(list_versions=True)

@pytest.mark.skipif(
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
reason="Contact JumpStart team to fix flaky test.",
)
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
def test_list_jumpstart_models_vulnerable_models(
Expand All @@ -547,12 +539,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
patched_read_s3_file.side_effect = vulnerable_inference_model_spec

num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
assert [] == list_jumpstart_models(
And("inference_vulnerable is false", "training_vulnerable is false")
And(
"inference_vulnerable is false",
"training_vulnerable is false",
"model_type is open_weights",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to include "model_type is open_weights" or is it still included by default?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not included by default by design, list_jumpstart_models has both open weights and proprietary models for discoverability, and "model_type is open_weights" filter is used to only list open weights models.

)
)

assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_read_s3_file.call_count == num_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand All @@ -561,10 +556,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
patched_read_s3_file.side_effect = vulnerable_training_model_spec

assert [] == list_jumpstart_models(
And("inference_vulnerable is false", "training_vulnerable is false")
And(
"inference_vulnerable is false",
"training_vulnerable is false",
"model_type is open_weights",
)
)

assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_read_s3_file.call_count == num_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand All @@ -574,10 +573,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):

assert patched_read_s3_file.call_count == 0

@pytest.mark.skipif(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm still confused why this was flakey in the first place and how your PR fixed this

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was flaky before we introduced prop models, but somehow seen failing more often after prop models. I updated to only list open weight models, and tried starting the test in multiple threads in parallel locally and they haven't been failing once for me. I think there is less parallelism in PR hooks now, so hope this is not flaky anymore.

datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1),
reason="Contact JumpStart team to fix flaky test.",
)
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file")
def test_list_jumpstart_models_deprecated_models(
Expand All @@ -598,10 +593,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
patched_read_s3_file.side_effect = deprecated_model_spec

num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT)
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST)
assert [] == list_jumpstart_models("deprecated equals false")
assert [] == list_jumpstart_models(
And("deprecated equals false", "model_type is open_weights")
)

assert patched_read_s3_file.call_count == num_specs + num_prop_specs
assert patched_read_s3_file.call_count == num_specs
assert patched_get_manifest.call_count == 2

patched_get_manifest.reset_mock()
Expand Down