-
Notifications
You must be signed in to change notification settings - Fork 1.2k
chore: update skipped flaky tests #4644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,6 @@ | |
|
||
from unittest import TestCase | ||
from unittest.mock import Mock, patch | ||
import datetime | ||
|
||
import pytest | ||
from sagemaker.jumpstart.constants import ( | ||
|
@@ -17,7 +16,6 @@ | |
get_prototype_manifest, | ||
get_prototype_model_spec, | ||
) | ||
from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST | ||
from sagemaker.jumpstart.enums import JumpStartModelType | ||
from sagemaker.jumpstart.notebook_utils import ( | ||
_generate_jumpstart_model_versions, | ||
|
@@ -227,10 +225,6 @@ def test_list_jumpstart_models_simple_case( | |
patched_get_manifest.assert_called() | ||
patched_get_model_specs.assert_not_called() | ||
|
||
@pytest.mark.skipif( | ||
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), | ||
reason="Contact JumpStart team to fix flaky test.", | ||
) | ||
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") | ||
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") | ||
def test_list_jumpstart_models_script_filter( | ||
|
@@ -246,23 +240,25 @@ def test_list_jumpstart_models_script_filter( | |
manifest_length = len(get_prototype_manifest()) | ||
vals = [True, False] | ||
for val in vals: | ||
kwargs = {"filter": f"training_supported == {val}"} | ||
kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")} | ||
list_jumpstart_models(**kwargs) | ||
assert patched_read_s3_file.call_count == manifest_length | ||
patched_get_manifest.assert_called_once() | ||
assert patched_get_manifest.call_count == 2 | ||
|
||
patched_get_manifest.reset_mock() | ||
patched_read_s3_file.reset_mock() | ||
|
||
kwargs = {"filter": f"training_supported != {val}"} | ||
kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")} | ||
list_jumpstart_models(**kwargs) | ||
assert patched_read_s3_file.call_count == manifest_length | ||
assert patched_get_manifest.call_count == 2 | ||
|
||
patched_get_manifest.reset_mock() | ||
patched_read_s3_file.reset_mock() | ||
|
||
kwargs = {"filter": f"training_supported in {vals}", "list_versions": True} | ||
kwargs = { | ||
"filter": And(f"training_supported != {val}", "model_type is open_weights"), | ||
"list_versions": True, | ||
} | ||
assert list_jumpstart_models(**kwargs) == [ | ||
("catboost-classification-model", "1.0.0"), | ||
("huggingface-spc-bert-base-cased", "1.0.0"), | ||
|
@@ -279,7 +275,7 @@ def test_list_jumpstart_models_script_filter( | |
patched_get_manifest.reset_mock() | ||
patched_read_s3_file.reset_mock() | ||
|
||
kwargs = {"filter": f"training_supported not in {vals}"} | ||
kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")} | ||
models = list_jumpstart_models(**kwargs) | ||
assert [] == models | ||
assert patched_read_s3_file.call_count == manifest_length | ||
|
@@ -518,10 +514,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): | |
list_old_models=False, list_versions=True | ||
) == list_jumpstart_models(list_versions=True) | ||
|
||
@pytest.mark.skipif( | ||
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), | ||
reason="Contact JumpStart team to fix flaky test.", | ||
) | ||
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") | ||
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") | ||
def test_list_jumpstart_models_vulnerable_models( | ||
|
@@ -547,12 +539,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): | |
patched_read_s3_file.side_effect = vulnerable_inference_model_spec | ||
|
||
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) | ||
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) | ||
assert [] == list_jumpstart_models( | ||
And("inference_vulnerable is false", "training_vulnerable is false") | ||
And( | ||
"inference_vulnerable is false", | ||
"training_vulnerable is false", | ||
"model_type is open_weights", | ||
) | ||
) | ||
|
||
assert patched_read_s3_file.call_count == num_specs + num_prop_specs | ||
assert patched_read_s3_file.call_count == num_specs | ||
assert patched_get_manifest.call_count == 2 | ||
|
||
patched_get_manifest.reset_mock() | ||
|
@@ -561,10 +556,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): | |
patched_read_s3_file.side_effect = vulnerable_training_model_spec | ||
|
||
assert [] == list_jumpstart_models( | ||
And("inference_vulnerable is false", "training_vulnerable is false") | ||
And( | ||
"inference_vulnerable is false", | ||
"training_vulnerable is false", | ||
"model_type is open_weights", | ||
) | ||
) | ||
|
||
assert patched_read_s3_file.call_count == num_specs + num_prop_specs | ||
assert patched_read_s3_file.call_count == num_specs | ||
assert patched_get_manifest.call_count == 2 | ||
|
||
patched_get_manifest.reset_mock() | ||
|
@@ -574,10 +573,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): | |
|
||
assert patched_read_s3_file.call_count == 0 | ||
|
||
@pytest.mark.skipif( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm still confused why this was flaky in the first place and how your PR fixed this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this was flaky before we introduced proprietary models, but somehow it was seen failing more often after proprietary models were added. I updated the test to only list open-weights models, and tried starting the test in multiple threads in parallel locally; it hasn't failed once for me. I think there is less parallelism in PR hooks now, so I hope this is no longer flaky. |
||
datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), | ||
reason="Contact JumpStart team to fix flaky test.", | ||
) | ||
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") | ||
@patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") | ||
def test_list_jumpstart_models_deprecated_models( | ||
|
@@ -598,10 +593,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: | |
patched_read_s3_file.side_effect = deprecated_model_spec | ||
|
||
num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) | ||
num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) | ||
assert [] == list_jumpstart_models("deprecated equals false") | ||
assert [] == list_jumpstart_models( | ||
And("deprecated equals false", "model_type is open_weights") | ||
) | ||
|
||
assert patched_read_s3_file.call_count == num_specs + num_prop_specs | ||
assert patched_read_s3_file.call_count == num_specs | ||
assert patched_get_manifest.call_count == 2 | ||
|
||
patched_get_manifest.reset_mock() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to include "model_type is open_weights" or is it still included by default?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is not included by default by design: `list_jumpstart_models` returns both open-weights and proprietary models for discoverability, and the `"model_type is open_weights"` filter is used to list only open-weights models.