Skip to content

Commit 3153f5a

Browse files
authored
fix: excessive jumpstart logging (#4023)
1 parent 219ad24 commit 3153f5a

File tree

9 files changed

+124
-119
lines changed

9 files changed

+124
-119
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,10 @@ def _retrieval_function(
327327
)
328328
if file_type == JumpStartS3FileType.SPECS:
329329
formatted_body, _ = self._get_json_file(s3_key, file_type)
330+
model_specs = JumpStartModelSpecs(formatted_body)
331+
utils.emit_logs_based_on_model_specs(model_specs, self.get_region())
330332
return JumpStartCachedS3ContentValue(
331-
formatted_content=JumpStartModelSpecs(formatted_body)
333+
formatted_content=model_specs
332334
)
333335
raise ValueError(
334336
f"Bad value for key '{key}': must be in {[JumpStartS3FileType.MANIFEST, JumpStartS3FileType.SPECS]}"

src/sagemaker/jumpstart/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores constants related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15+
import logging
1516
from typing import Dict, Set, Type
1617
import boto3
1718
from sagemaker.base_deserializers import BaseDeserializer, JSONDeserializer
@@ -173,3 +174,5 @@
173174
}
174175

175176
MODEL_ID_LIST_WEB_URL = "https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html"
177+
178+
JUMPSTART_LOGGER = logging.getLogger("sagemaker.jumpstart")

src/sagemaker/jumpstart/estimator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores JumpStart implementation of Estimator class."""
1414
from __future__ import absolute_import
15-
import logging
1615

1716

1817
from typing import Dict, List, Optional, Union
@@ -45,8 +44,6 @@
4544
from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
4645
from sagemaker.workflow.entities import PipelineVariable
4746

48-
logger = logging.getLogger(__name__)
49-
5047

5148
class JumpStartEstimator(Estimator):
5249
"""JumpStartEstimator class.

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores JumpStart Estimator factory methods."""
1414
from __future__ import absolute_import
15-
import logging
1615

1716

1817
from typing import Dict, List, Optional, Union
@@ -41,6 +40,7 @@
4140
)
4241
from sagemaker.jumpstart.constants import (
4342
JUMPSTART_DEFAULT_REGION_NAME,
43+
JUMPSTART_LOGGER,
4444
TRAINING_ENTRY_POINT_SCRIPT_NAME,
4545
)
4646
from sagemaker.jumpstart.enums import JumpStartScriptScope
@@ -64,8 +64,6 @@
6464
from sagemaker.utils import name_from_base
6565
from sagemaker.workflow.entities import PipelineVariable
6666

67-
logger = logging.getLogger("sagemaker")
68-
6967

7068
def get_init_kwargs(
7169
model_id: str,
@@ -421,7 +419,7 @@ def _add_instance_type_and_count_to_kwargs(
421419
kwargs.instance_count = kwargs.instance_count or 1
422420

423421
if orig_instance_type is None:
424-
logger.info(
422+
JUMPSTART_LOGGER.info(
425423
"No instance type selected for training job. Defaulting to %s.", kwargs.instance_type
426424
)
427425

@@ -467,7 +465,7 @@ def _add_model_uri_to_kwargs(kwargs: JumpStartEstimatorInitKwargs) -> JumpStartE
467465
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
468466
)
469467
):
470-
logger.warning(
468+
JUMPSTART_LOGGER.warning(
471469
"'%s' does not support incremental training but is being trained with"
472470
" non-default model artifact.",
473471
kwargs.model_id,

src/sagemaker/jumpstart/factory/model.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module stores JumpStart Model factory methods."""
1414
from __future__ import absolute_import
15-
import logging
1615

1716

1817
from typing import Any, Dict, List, Optional, Union
@@ -31,6 +30,7 @@
3130
from sagemaker.jumpstart.constants import (
3231
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
3332
JUMPSTART_DEFAULT_REGION_NAME,
33+
JUMPSTART_LOGGER,
3434
)
3535
from sagemaker.jumpstart.enums import JumpStartScriptScope
3636
from sagemaker.jumpstart.types import (
@@ -51,8 +51,6 @@
5151
from sagemaker.utils import name_from_base
5252
from sagemaker.workflow.entities import PipelineVariable
5353

54-
logger = logging.getLogger("sagemaker")
55-
5654

5755
def get_default_predictor(
5856
predictor: Predictor,
@@ -170,7 +168,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
170168
)
171169

172170
if orig_instance_type is None:
173-
logger.info(
171+
JUMPSTART_LOGGER.info(
174172
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
175173
kwargs.instance_type,
176174
)

src/sagemaker/jumpstart/model.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
"""This module stores JumpStart implementation of Model class."""
1414

1515
from __future__ import absolute_import
16-
import logging
1716
import re
1817

1918
from typing import Dict, List, Optional, Union
@@ -38,8 +37,6 @@
3837
from sagemaker.session import Session
3938
from sagemaker.workflow.entities import PipelineVariable
4039

41-
logger = logging.getLogger(__name__)
42-
4340

4441
class JumpStartModel(Model):
4542
"""JumpStartModel class.

src/sagemaker/jumpstart/utils.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@
4242
from sagemaker.utils import resolve_value_from_config
4343
from sagemaker.workflow import is_pipeline_variable
4444

45-
LOGGER = logging.getLogger(__name__)
46-
4745

4846
def get_jumpstart_launched_regions_message() -> str:
4947
"""Returns formatted string indicating where JumpStart is launched."""
@@ -79,7 +77,7 @@ def get_jumpstart_content_bucket(
7977
and len(os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]) > 0
8078
):
8179
bucket_override = os.environ[constants.ENV_VARIABLE_JUMPSTART_CONTENT_BUCKET_OVERRIDE]
82-
LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
80+
constants.JUMPSTART_LOGGER.info("Using JumpStart bucket override: '%s'", bucket_override)
8381
return bucket_override
8482
try:
8583
return constants.JUMPSTART_REGION_NAME_TO_LAUNCHED_REGION_DICT[region].content_bucket
@@ -343,6 +341,39 @@ def update_inference_tags_with_jumpstart_training_tags(
343341
return inference_tags
344342

345343

344+
def emit_logs_based_on_model_specs(model_specs: JumpStartModelSpecs, region: str) -> None:
345+
"""Emits logs based on model specs and region."""
346+
347+
if model_specs.hosting_eula_key:
348+
constants.JUMPSTART_LOGGER.info(
349+
"Model '%s' requires accepting end-user license agreement (EULA). "
350+
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
351+
model_specs.model_id,
352+
get_jumpstart_content_bucket(region=region),
353+
region,
354+
".cn" if region.startswith("cn-") else "",
355+
model_specs.hosting_eula_key,
356+
)
357+
358+
if model_specs.deprecated:
359+
deprecated_message = model_specs.deprecated_message or (
360+
"Using deprecated JumpStart model "
361+
f"'{model_specs.model_id}' and version '{model_specs.version}'."
362+
)
363+
364+
constants.JUMPSTART_LOGGER.warning(deprecated_message)
365+
366+
if model_specs.deprecate_warn_message:
367+
constants.JUMPSTART_LOGGER.warning(model_specs.deprecate_warn_message)
368+
369+
if model_specs.inference_vulnerable or model_specs.training_vulnerable:
370+
constants.JUMPSTART_LOGGER.warning(
371+
"Using vulnerable JumpStart model '%s' and version '%s'.",
372+
model_specs.model_id,
373+
model_specs.version,
374+
)
375+
376+
346377
def verify_model_region_and_return_specs(
347378
model_id: Optional[str],
348379
version: Optional[str],
@@ -402,26 +433,11 @@ def verify_model_region_and_return_specs(
402433
f"JumpStart model ID '{model_id}' and version '{version}' " "does not support training."
403434
)
404435

405-
if model_specs.hosting_eula_key and scope == constants.JumpStartScriptScope.INFERENCE.value:
406-
LOGGER.info(
407-
"Model '%s' requires accepting end-user license agreement (EULA). "
408-
"See https://%s.s3.%s.amazonaws.com%s/%s for terms of use.",
409-
model_id,
410-
get_jumpstart_content_bucket(region=region),
411-
region,
412-
".cn" if region.startswith("cn-") else "",
413-
model_specs.hosting_eula_key,
414-
)
415-
416436
if model_specs.deprecated:
417437
if not tolerate_deprecated_model:
418438
raise DeprecatedJumpStartModelError(
419439
model_id=model_id, version=version, message=model_specs.deprecated_message
420440
)
421-
LOGGER.warning("Using deprecated JumpStart model '%s' and version '%s'.", model_id, version)
422-
423-
if model_specs.deprecate_warn_message:
424-
LOGGER.warning(model_specs.deprecate_warn_message)
425441

426442
if scope == constants.JumpStartScriptScope.INFERENCE.value and model_specs.inference_vulnerable:
427443
if not tolerate_vulnerable_model:
@@ -431,9 +447,6 @@ def verify_model_region_and_return_specs(
431447
vulnerabilities=model_specs.inference_vulnerabilities,
432448
scope=constants.JumpStartScriptScope.INFERENCE,
433449
)
434-
LOGGER.warning(
435-
"Using vulnerable JumpStart model '%s' and version '%s' (inference).", model_id, version
436-
)
437450

438451
if scope == constants.JumpStartScriptScope.TRAINING.value and model_specs.training_vulnerable:
439452
if not tolerate_vulnerable_model:
@@ -443,9 +456,6 @@ def verify_model_region_and_return_specs(
443456
vulnerabilities=model_specs.training_vulnerabilities,
444457
scope=constants.JumpStartScriptScope.TRAINING,
445458
)
446-
LOGGER.warning(
447-
"Using vulnerable JumpStart model '%s' and version '%s' (training).", model_id, version
448-
)
449459

450460
return model_specs
451461

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def test_yes_predictor_returns_unmodified_predictor(
742742

743743
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
744744
@mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training")
745-
@mock.patch("sagemaker.jumpstart.factory.estimator.logger.warning")
745+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
746746
@mock.patch("sagemaker.jumpstart.factory.model.Session")
747747
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
748748
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -795,7 +795,7 @@ def test_incremental_training_with_unsupported_model_logs_warning(
795795

796796
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
797797
@mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training")
798-
@mock.patch("sagemaker.jumpstart.factory.estimator.logger.warning")
798+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
799799
@mock.patch("sagemaker.jumpstart.factory.model.Session")
800800
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
801801
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")

0 commit comments

Comments
 (0)