Skip to content

Commit de641fe

Browse files
committed
fix: excessive jumpstart instance type logging
1 parent 8462f1a commit de641fe

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def get_deploy_kwargs(
335335
tolerate_vulnerable_model=tolerate_vulnerable_model,
336336
tolerate_deprecated_model=tolerate_deprecated_model,
337337
training_instance_type=training_instance_type,
338+
disable_instance_type_logging=True,
338339
)
339340

340341
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(

src/sagemaker/jumpstart/factory/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,9 @@ def _add_vulnerable_and_deprecated_status_to_kwargs(
168168
return kwargs
169169

170170

171-
def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
171+
def _add_instance_type_to_kwargs(
172+
kwargs: JumpStartModelInitKwargs, disable_logging: bool = False
173+
) -> JumpStartModelInitKwargs:
172174
"""Sets instance type based on default or override, returns full kwargs."""
173175

174176
orig_instance_type = kwargs.instance_type
@@ -184,7 +186,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
184186
training_instance_type=kwargs.training_instance_type,
185187
)
186188

187-
if orig_instance_type is None:
189+
if not disable_logging and orig_instance_type is None:
188190
JUMPSTART_LOGGER.info(
189191
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
190192
kwargs.instance_type,
@@ -524,9 +526,7 @@ def get_deploy_kwargs(
524526

525527
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
526528

527-
deploy_kwargs = _add_instance_type_to_kwargs(
528-
kwargs=deploy_kwargs,
529-
)
529+
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
530530

531531
deploy_kwargs.initial_instance_count = initial_instance_count or 1
532532

@@ -645,6 +645,7 @@ def get_init_kwargs(
645645
git_config: Optional[Dict[str, str]] = None,
646646
model_package_arn: Optional[str] = None,
647647
training_instance_type: Optional[str] = None,
648+
disable_instance_type_logging: bool = False,
648649
) -> JumpStartModelInitKwargs:
649650
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
650651

@@ -686,7 +687,7 @@ def get_init_kwargs(
686687
model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
687688

688689
model_init_kwargs = _add_instance_type_to_kwargs(
689-
kwargs=model_init_kwargs,
690+
kwargs=model_init_kwargs, disable_logging=disable_instance_type_logging
690691
)
691692

692693
model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs)

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@
5050

5151

5252
class EstimatorTest(unittest.TestCase):
53+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER")
54+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER")
5355
@mock.patch("sagemaker.utils.sagemaker_timestamp")
5456
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
5557
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -70,6 +72,8 @@ def test_non_prepacked(
7072
mock_session_model: mock.Mock,
7173
mock_is_valid_model_id: mock.Mock,
7274
mock_sagemaker_timestamp: mock.Mock,
75+
mock_jumpstart_model_factory_logger: mock.Mock,
76+
mock_jumpstart_estimator_factory_logger: mock.Mock,
7377
):
7478
mock_is_valid_model_id.return_value = True
7579

@@ -87,6 +91,9 @@ def test_non_prepacked(
8791
estimator = JumpStartEstimator(
8892
model_id=model_id,
8993
)
94+
mock_jumpstart_estimator_factory_logger.info.assert_called_once_with(
95+
"No instance type selected for training job. Defaulting to %s.", "ml.p3.2xlarge"
96+
)
9097

9198
mock_estimator_init.assert_called_once_with(
9299
instance_type="ml.p3.2xlarge",
@@ -124,13 +131,22 @@ def test_non_prepacked(
124131
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
125132
}
126133

134+
mock_jumpstart_estimator_factory_logger.info.reset_mock()
127135
estimator.fit(channels)
136+
mock_jumpstart_estimator_factory_logger.info.assert_not_called()
128137

129138
mock_estimator_fit.assert_called_once_with(
130139
inputs=channels, wait=True, job_name="blahblahblah-9876"
131140
)
132141

142+
mock_jumpstart_model_factory_logger.info.reset_mock()
143+
mock_jumpstart_estimator_factory_logger.info.reset_mock()
133144
estimator.deploy()
145+
mock_jumpstart_model_factory_logger.info.assert_called_once_with(
146+
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
147+
"ml.p2.xlarge",
148+
)
149+
mock_jumpstart_estimator_factory_logger.info.assert_not_called()
134150

135151
mock_estimator_deploy.assert_called_once_with(
136152
instance_type="ml.p2.xlarge",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class ModelTest(unittest.TestCase):
4343

4444
mock_session_empty_config = MagicMock(sagemaker_config={})
4545

46+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER")
4647
@mock.patch("sagemaker.utils.sagemaker_timestamp")
4748
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
4849
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -58,6 +59,7 @@ def test_non_prepacked(
5859
mock_session: mock.Mock,
5960
mock_is_valid_model_id: mock.Mock,
6061
mock_sagemaker_timestamp: mock.Mock,
62+
mock_jumpstart_model_factory_logger: mock.Mock,
6163
):
6264
mock_model_deploy.return_value = default_predictor
6365

@@ -70,9 +72,14 @@ def test_non_prepacked(
7072

7173
mock_session.return_value = sagemaker_session
7274

75+
mock_jumpstart_model_factory_logger.info.reset_mock()
7376
model = JumpStartModel(
7477
model_id=model_id,
7578
)
79+
mock_jumpstart_model_factory_logger.info.assert_called_once_with(
80+
"No " "instance type selected for inference hosting endpoint. " "Defaulting to %s.",
81+
"ml.p2.xlarge",
82+
)
7683

7784
mock_model_init.assert_called_once_with(
7885
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/"
@@ -96,7 +103,9 @@ def test_non_prepacked(
96103
name="blahblahblah-7777",
97104
)
98105

106+
mock_jumpstart_model_factory_logger.info.reset_mock()
99107
model.deploy()
108+
mock_jumpstart_model_factory_logger.info.assert_not_called()
100109

101110
mock_model_deploy.assert_called_once_with(
102111
initial_instance_count=1,

0 commit comments

Comments
 (0)