Skip to content

Commit 7860f16

Browse files
authored
fix: Chore/reset cache if js model not found (#3945)
1 parent 6a7002e commit 7860f16

File tree

4 files changed

+183
-14
lines changed

4 files changed

+183
-14
lines changed

src/sagemaker/jumpstart/estimator.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from sagemaker.explainer.explainer_config import ExplainerConfig
2828
from sagemaker.inputs import FileSystemInput, TrainingInput
2929
from sagemaker.instance_group import InstanceGroup
30+
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
3031
from sagemaker.jumpstart.enums import JumpStartScriptScope
3132
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
3233

@@ -488,13 +489,18 @@ def __init__(
488489
ValueError: If the model ID is not recognized by JumpStart.
489490
"""
490491

491-
if not is_valid_model_id(
492-
model_id=model_id,
493-
model_version=model_version,
494-
region=region,
495-
script=JumpStartScriptScope.TRAINING,
496-
):
497-
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
492+
def _is_valid_model_id_hook():
493+
return is_valid_model_id(
494+
model_id=model_id,
495+
model_version=model_version,
496+
region=region,
497+
script=JumpStartScriptScope.TRAINING,
498+
)
499+
500+
if not _is_valid_model_id_hook():
501+
JumpStartModelsAccessor.reset_cache()
502+
if not _is_valid_model_id_hook():
503+
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
498504

499505
estimator_init_kwargs = get_init_kwargs(
500506
model_id=model_id,

src/sagemaker/jumpstart/model.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sagemaker.base_deserializers import BaseDeserializer
2121
from sagemaker.base_serializers import BaseSerializer
2222
from sagemaker.explainer.explainer_config import ExplainerConfig
23+
from sagemaker.jumpstart.accessors import JumpStartModelsAccessor
2324
from sagemaker.jumpstart.enums import JumpStartScriptScope
2425
from sagemaker.jumpstart.exceptions import INVALID_MODEL_ID_ERROR_MSG
2526
from sagemaker.jumpstart.factory.model import (
@@ -252,13 +253,18 @@ def __init__(
252253
ValueError: If the model ID is not recognized by JumpStart.
253254
"""
254255

255-
if not is_valid_model_id(
256-
model_id=model_id,
257-
model_version=model_version,
258-
region=region,
259-
script=JumpStartScriptScope.INFERENCE,
260-
):
261-
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
256+
def _is_valid_model_id_hook():
257+
return is_valid_model_id(
258+
model_id=model_id,
259+
model_version=model_version,
260+
region=region,
261+
script=JumpStartScriptScope.INFERENCE,
262+
)
263+
264+
if not _is_valid_model_id_hook():
265+
JumpStartModelsAccessor.reset_cache()
266+
if not _is_valid_model_id_hook():
267+
raise ValueError(INVALID_MODEL_ID_ERROR_MSG.format(model_id=model_id))
262268

263269
model_init_kwargs = get_init_kwargs(
264270
model_id=model_id,

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from sagemaker.debugger.profiler_config import ProfilerConfig
2323
from sagemaker.estimator import Estimator
2424
from sagemaker.instance_group import InstanceGroup
25+
from sagemaker.jumpstart.enums import JumpStartScriptScope
2526

2627
from sagemaker.jumpstart.estimator import JumpStartEstimator
2728

@@ -940,6 +941,85 @@ def test_training_passes_session_to_deploy(
940941
endpoint_name="blahblahblah-3456",
941942
)
942943

944+
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
945+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
946+
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
947+
@mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs")
948+
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
949+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
950+
@mock.patch("sagemaker.jumpstart.estimator.JumpStartModelsAccessor.reset_cache")
951+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
952+
def test_model_id_not_found_refeshes_cache_training(
953+
self,
954+
mock_reset_cache: mock.Mock,
955+
mock_get_model_specs: mock.Mock,
956+
mock_session: mock.Mock,
957+
mock_retrieve_kwargs: mock.Mock,
958+
mock_estimator_init: mock.Mock,
959+
mock_estimator_deploy: mock.Mock,
960+
mock_is_valid_model_id: mock.Mock,
961+
):
962+
mock_estimator_deploy.return_value = default_predictor
963+
964+
mock_is_valid_model_id.side_effect = [False, False]
965+
966+
model_id, _ = "js-trainable-model", "*"
967+
968+
mock_retrieve_kwargs.return_value = {}
969+
970+
mock_get_model_specs.side_effect = get_special_model_spec
971+
972+
mock_session.return_value = sagemaker_session
973+
974+
with pytest.raises(ValueError):
975+
JumpStartEstimator(
976+
model_id=model_id,
977+
)
978+
979+
mock_reset_cache.assert_called_once_with()
980+
mock_is_valid_model_id.assert_has_calls(
981+
calls=[
982+
mock.call(
983+
model_id="js-trainable-model",
984+
model_version=None,
985+
region=None,
986+
script=JumpStartScriptScope.TRAINING,
987+
),
988+
mock.call(
989+
model_id="js-trainable-model",
990+
model_version=None,
991+
region=None,
992+
script=JumpStartScriptScope.TRAINING,
993+
),
994+
]
995+
)
996+
997+
mock_is_valid_model_id.reset_mock()
998+
mock_reset_cache.reset_mock()
999+
1000+
mock_is_valid_model_id.side_effect = [False, True]
1001+
JumpStartEstimator(
1002+
model_id=model_id,
1003+
)
1004+
1005+
mock_reset_cache.assert_called_once_with()
1006+
mock_is_valid_model_id.assert_has_calls(
1007+
calls=[
1008+
mock.call(
1009+
model_id="js-trainable-model",
1010+
model_version=None,
1011+
region=None,
1012+
script=JumpStartScriptScope.TRAINING,
1013+
),
1014+
mock.call(
1015+
model_id="js-trainable-model",
1016+
model_version=None,
1017+
region=None,
1018+
script=JumpStartScriptScope.TRAINING,
1019+
),
1020+
]
1021+
)
1022+
9431023

9441024
def test_jumpstart_estimator_requires_model_id():
9451025
with pytest.raises(ValueError):

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest import mock
1717
import unittest
1818
import pytest
19+
from sagemaker.jumpstart.enums import JumpStartScriptScope
1920

2021
from sagemaker.jumpstart.model import JumpStartModel
2122
from sagemaker.model import Model
@@ -452,6 +453,82 @@ def test_yes_predictor_returns_default_predictor(
452453
self.assertEqual(type(predictor), Predictor)
453454
self.assertEqual(predictor, default_predictor)
454455

456+
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
457+
@mock.patch("sagemaker.jumpstart.model.Model.__init__")
458+
@mock.patch("sagemaker.jumpstart.factory.model._retrieve_model_init_kwargs")
459+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
460+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
461+
@mock.patch("sagemaker.jumpstart.model.JumpStartModelsAccessor.reset_cache")
462+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
463+
def test_model_id_not_found_refeshes_cach_inference(
464+
self,
465+
mock_reset_cache: mock.Mock,
466+
mock_get_model_specs: mock.Mock,
467+
mock_session: mock.Mock,
468+
mock_retrieve_kwargs: mock.Mock,
469+
mock_model_init: mock.Mock,
470+
mock_is_valid_model_id: mock.Mock,
471+
):
472+
473+
mock_is_valid_model_id.side_effect = [False, False]
474+
475+
model_id, _ = "js-trainable-model", "*"
476+
477+
mock_retrieve_kwargs.return_value = {}
478+
479+
mock_get_model_specs.side_effect = get_special_model_spec
480+
481+
mock_session.return_value = sagemaker_session
482+
483+
with pytest.raises(ValueError):
484+
JumpStartModel(
485+
model_id=model_id,
486+
)
487+
488+
mock_reset_cache.assert_called_once_with()
489+
mock_is_valid_model_id.assert_has_calls(
490+
calls=[
491+
mock.call(
492+
model_id="js-trainable-model",
493+
model_version=None,
494+
region=None,
495+
script=JumpStartScriptScope.INFERENCE,
496+
),
497+
mock.call(
498+
model_id="js-trainable-model",
499+
model_version=None,
500+
region=None,
501+
script=JumpStartScriptScope.INFERENCE,
502+
),
503+
]
504+
)
505+
506+
mock_is_valid_model_id.reset_mock()
507+
mock_reset_cache.reset_mock()
508+
509+
mock_is_valid_model_id.side_effect = [False, True]
510+
JumpStartModel(
511+
model_id=model_id,
512+
)
513+
514+
mock_reset_cache.assert_called_once_with()
515+
mock_is_valid_model_id.assert_has_calls(
516+
calls=[
517+
mock.call(
518+
model_id="js-trainable-model",
519+
model_version=None,
520+
region=None,
521+
script=JumpStartScriptScope.INFERENCE,
522+
),
523+
mock.call(
524+
model_id="js-trainable-model",
525+
model_version=None,
526+
region=None,
527+
script=JumpStartScriptScope.INFERENCE,
528+
),
529+
]
530+
)
531+
455532

456533
def test_jumpstart_model_requires_model_id():
457534
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)