diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index c7396cdec5..baa9d55085 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -336,6 +336,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, + disable_instance_type_logging=True, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 009fc699de..185beefc59 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -171,7 +171,9 @@ def _add_vulnerable_and_deprecated_status_to_kwargs( return kwargs -def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: +def _add_instance_type_to_kwargs( + kwargs: JumpStartModelInitKwargs, disable_instance_type_logging: bool = False +) -> JumpStartModelInitKwargs: """Sets instance type based on default or override, returns full kwargs.""" orig_instance_type = kwargs.instance_type @@ -187,7 +189,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM training_instance_type=kwargs.training_instance_type, ) - if orig_instance_type is None: + if not disable_instance_type_logging and orig_instance_type is None: JUMPSTART_LOGGER.info( "No instance type selected for inference hosting endpoint. Defaulting to %s.", kwargs.instance_type, @@ -551,9 +553,7 @@ def get_deploy_kwargs( deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) - deploy_kwargs = _add_instance_type_to_kwargs( - kwargs=deploy_kwargs, - ) + deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs.initial_instance_count = initial_instance_count or 1 @@ -677,6 +677,7 @@ def get_init_kwargs( git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, training_instance_type: Optional[str] = None, + disable_instance_type_logging: bool = False, resources: Optional[ResourceRequirements] = None, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -720,7 +721,7 @@ def get_init_kwargs( model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_instance_type_to_kwargs( - kwargs=model_init_kwargs, + kwargs=model_init_kwargs, disable_instance_type_logging=disable_instance_type_logging ) model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 9476e3c1fd..29eff40461 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -57,6 +57,8 @@ class EstimatorTest(unittest.TestCase): + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -77,6 +79,8 @@ def test_non_prepacked( mock_session_model: mock.Mock, mock_is_valid_model_id: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_jumpstart_model_factory_logger: mock.Mock, + mock_jumpstart_estimator_factory_logger: mock.Mock, ): mock_is_valid_model_id.return_value = True @@ -94,6 +98,9 @@ def test_non_prepacked( estimator = JumpStartEstimator( model_id=model_id, ) + mock_jumpstart_estimator_factory_logger.info.assert_called_once_with( + "No instance type selected for training job. Defaulting to %s.", "ml.p3.2xlarge" + ) mock_estimator_init.assert_called_once_with( instance_type="ml.p3.2xlarge", @@ -131,13 +138,22 @@ def test_non_prepacked( f"{get_training_dataset_for_model_and_version(model_id, model_version)}", } + mock_jumpstart_estimator_factory_logger.info.reset_mock() estimator.fit(channels) + mock_jumpstart_estimator_factory_logger.info.assert_not_called() mock_estimator_fit.assert_called_once_with( inputs=channels, wait=True, job_name="blahblahblah-9876" ) + mock_jumpstart_model_factory_logger.info.reset_mock() + mock_jumpstart_estimator_factory_logger.info.reset_mock() estimator.deploy() + mock_jumpstart_model_factory_logger.info.assert_called_once_with( + "No instance type selected for inference hosting endpoint. Defaulting to %s.", + "ml.p2.xlarge", + ) + mock_jumpstart_estimator_factory_logger.info.assert_not_called() mock_estimator_deploy.assert_called_once_with( instance_type="ml.p2.xlarge", diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index ff3e670e53..f45283935b 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -51,6 +51,7 @@ class ModelTest(unittest.TestCase): mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -66,6 +67,7 @@ def test_non_prepacked( mock_session: mock.Mock, mock_is_valid_model_id: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_jumpstart_model_factory_logger: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -78,9 +80,14 @@ def test_non_prepacked( mock_session.return_value = sagemaker_session + mock_jumpstart_model_factory_logger.info.reset_mock() model = JumpStartModel( model_id=model_id, ) + mock_jumpstart_model_factory_logger.info.assert_called_once_with( + "No " "instance type selected for inference hosting endpoint. " "Defaulting to %s.", + "ml.p2.xlarge", + ) mock_model_init.assert_called_once_with( image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" @@ -104,7 +111,9 @@ def test_non_prepacked( name="blahblahblah-7777", ) + mock_jumpstart_model_factory_logger.info.reset_mock() model.deploy() + mock_jumpstart_model_factory_logger.info.assert_not_called() mock_model_deploy.assert_called_once_with( initial_instance_count=1,