Skip to content

Commit 738df16

Browse files
committed
chore: improve error msg for deprecated and vulnerable models, fix pylint
1 parent b6e2cfd commit 738df16

File tree

6 files changed

+33
-22
lines changed

6 files changed

+33
-22
lines changed

src/sagemaker/jumpstart/exceptions.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,17 @@ def __init__(
8181
self.message = (
8282
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
8383
"has at least 1 vulnerable dependency in the inference script. "
84-
"Please try targetting a higher version of the model. "
85-
f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore
84+
"Please try targetting a higher version of the model or using a "
85+
"different model. List of vulnerabilities: "
86+
f"{', '.join(vulnerabilities)}" # type: ignore
8687
)
8788
elif scope == JumpStartScriptScope.TRAINING:
8889
self.message = (
8990
f"Version '{version}' of JumpStart model '{model_id}' " # type: ignore
9091
"has at least 1 vulnerable dependency in the training script. "
91-
"Please try targetting a higher version of the model. "
92-
f"List of vulnerabilities: {', '.join(vulnerabilities)}" # type: ignore
92+
"Please try targetting a higher version of the model or using a "
93+
"different model. List of vulnerabilities: "
94+
f"{', '.join(vulnerabilities)}" # type: ignore
9395
)
9496
else:
9597
raise NotImplementedError(
@@ -123,7 +125,8 @@ def __init__(
123125
raise RuntimeError("Must specify `model_id` and `version` arguments.")
124126
self.message = (
125127
f"Version '{version}' of JumpStart model '{model_id}' is deprecated. "
126-
"Please try targetting a higher version of the model."
128+
"Please try targetting a higher version of the model or using a "
129+
"different model."
127130
)
128131

129132
super().__init__(self.message)

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,8 @@ def _add_instance_type_and_count_to_kwargs(
412412
kwargs.instance_count = kwargs.instance_count or 1
413413

414414
if orig_instance_type is None:
415-
logger.info( # pylint: disable=W1203
416-
f"No instance type selected for training job. Defaulting to {kwargs.instance_type}."
415+
logger.info(
416+
"No instance type selected for training job. Defaulting to %s.", kwargs.instance_type
417417
)
418418

419419
return kwargs
@@ -458,9 +458,10 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
458458
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
459459
)
460460
):
461-
logger.warning( # pylint: disable=W1203
462-
f"'{kwargs.model_id}' does not support incremental training but is being trained with"
463-
" non-default model artifact."
461+
logger.warning(
462+
"'%s' does not support incremental training but is being trained with"
463+
" non-default model artifact.",
464+
kwargs.model_id,
464465
)
465466

466467
kwargs.model_uri = kwargs.model_uri or default_model_uri

src/sagemaker/jumpstart/factory/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
168168
)
169169

170170
if orig_instance_type is None:
171-
logger.info( # pylint: disable=W1203
172-
"No instance type selected for inference hosting endpoint. "
173-
f"Defaulting to {kwargs.instance_type}."
171+
logger.info(
172+
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
173+
kwargs.instance_type,
174174
)
175175

176176
return kwargs

src/sagemaker/model.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -612,11 +612,15 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
612612
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)
613613
)
614614

615-
LOGGER.info( # pylint: disable=W1203
616-
f"Repacking model artifact ({self.model_data}), script artifact "
617-
f"({self.source_dir}), and dependencies ({self.dependencies}) "
618-
f"into single tar.gz file located at {repacked_model_data}. "
619-
"This may take some time depending on model size..."
615+
LOGGER.info(
616+
"Repacking model artifact (%s), script artifact "
617+
"(%s), and dependencies (%s) "
618+
"into single tar.gz file located at %s. "
619+
"This may take some time depending on model size...",
620+
self.model_data,
621+
self.source_dir,
622+
self.dependencies,
623+
repacked_model_data,
620624
)
621625

622626
utils.repack_model(

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,8 @@ def test_incremental_training_with_unsupported_model_logs_warning(
715715
)
716716

717717
mock_logger_warning.assert_called_once_with(
718-
f"'{model_id}' does not support incremental training but is being trained with non-default model artifact."
718+
"'%s' does not support incremental training but is being trained with non-default model artifact.",
719+
model_id,
719720
)
720721
mock_supports_incremental_training.assert_called_once_with(
721722
model_id=model_id,

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,7 @@ def make_vulnerable_inference_spec(*largs, **kwargs):
788788
assert (
789789
"Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 "
790790
"vulnerable dependency in the inference script. "
791-
"Please try targetting a higher version of the model. "
791+
"Please try targetting a higher version of the model or using a different model. "
792792
"List of vulnerabilities: some, vulnerability"
793793
) == str(e.value.message)
794794

@@ -827,7 +827,7 @@ def make_vulnerable_training_spec(*largs, **kwargs):
827827
assert (
828828
"Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' has at least 1 "
829829
"vulnerable dependency in the training script. "
830-
"Please try targetting a higher version of the model. "
830+
"Please try targetting a higher version of the model or using a different model. "
831831
"List of vulnerabilities: some, vulnerability"
832832
) == str(e.value.message)
833833

@@ -866,7 +866,9 @@ def make_deprecated_spec(*largs, **kwargs):
866866
region="us-west-2",
867867
)
868868
assert "Version '*' of JumpStart model 'pytorch-eqa-bert-base-cased' is deprecated. "
869-
"Please try targetting a higher version of the model." == str(e.value.message)
869+
"Please try targetting a higher version of the model or using a different model." == str(
870+
e.value.message
871+
)
870872

871873
with patch("logging.Logger.warning") as mocked_warning_log:
872874
assert (

0 commit comments

Comments
 (0)