From de641fe498e0cd499931076b0f8c55e30d85a76c Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 13 Nov 2023 16:46:53 +0000 Subject: [PATCH 1/2] fix: excessive jumpstart instance type logging --- src/sagemaker/jumpstart/factory/estimator.py | 1 + src/sagemaker/jumpstart/factory/model.py | 13 +++++++------ .../jumpstart/estimator/test_estimator.py | 16 ++++++++++++++++ .../unit/sagemaker/jumpstart/model/test_model.py | 9 +++++++++ 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 1b24b714e7..c78fc757ae 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -335,6 +335,7 @@ def get_deploy_kwargs( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, training_instance_type=training_instance_type, + disable_instance_type_logging=True, ) estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs( diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 19605774ed..8e891ec818 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -168,7 +168,9 @@ def _add_vulnerable_and_deprecated_status_to_kwargs( return kwargs -def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs: +def _add_instance_type_to_kwargs( + kwargs: JumpStartModelInitKwargs, disable_logging: bool = False +) -> JumpStartModelInitKwargs: """Sets instance type based on default or override, returns full kwargs.""" orig_instance_type = kwargs.instance_type @@ -184,7 +186,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM training_instance_type=kwargs.training_instance_type, ) - if orig_instance_type is None: + if not disable_logging and orig_instance_type is None: JUMPSTART_LOGGER.info( "No instance type selected for inference hosting endpoint. Defaulting to %s.", kwargs.instance_type, @@ -524,9 +526,7 @@ def get_deploy_kwargs( deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) - deploy_kwargs = _add_instance_type_to_kwargs( - kwargs=deploy_kwargs, - ) + deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs) deploy_kwargs.initial_instance_count = initial_instance_count or 1 @@ -645,6 +645,7 @@ def get_init_kwargs( git_config: Optional[Dict[str, str]] = None, model_package_arn: Optional[str] = None, training_instance_type: Optional[str] = None, + disable_instance_type_logging: bool = False, ) -> JumpStartModelInitKwargs: """Returns kwargs required to instantiate `sagemaker.estimator.Model` object.""" @@ -686,7 +687,7 @@ def get_init_kwargs( model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_instance_type_to_kwargs( - kwargs=model_init_kwargs, + kwargs=model_init_kwargs, disable_logging=disable_instance_type_logging ) model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs) diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 6f4788fa04..bd4df9b8f9 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -50,6 +50,8 @@ class EstimatorTest(unittest.TestCase): + @mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER") + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -70,6 +72,8 @@ def test_non_prepacked( mock_session_model: mock.Mock, mock_is_valid_model_id: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_jumpstart_model_factory_logger: mock.Mock, + mock_jumpstart_estimator_factory_logger: mock.Mock, ): mock_is_valid_model_id.return_value = True @@ -87,6 +91,9 @@ def test_non_prepacked( estimator = JumpStartEstimator( model_id=model_id, ) + mock_jumpstart_estimator_factory_logger.info.assert_called_once_with( + "No instance type selected for training job. Defaulting to %s.", "ml.p3.2xlarge" + ) mock_estimator_init.assert_called_once_with( instance_type="ml.p3.2xlarge", @@ -124,13 +131,22 @@ def test_non_prepacked( f"{get_training_dataset_for_model_and_version(model_id, model_version)}", } + mock_jumpstart_estimator_factory_logger.info.reset_mock() estimator.fit(channels) + mock_jumpstart_estimator_factory_logger.info.assert_not_called() mock_estimator_fit.assert_called_once_with( inputs=channels, wait=True, job_name="blahblahblah-9876" ) + mock_jumpstart_model_factory_logger.info.reset_mock() + mock_jumpstart_estimator_factory_logger.info.reset_mock() estimator.deploy() + mock_jumpstart_model_factory_logger.info.assert_called_once_with( + "No instance type selected for inference hosting endpoint. Defaulting to %s.", + "ml.p2.xlarge", + ) + mock_jumpstart_estimator_factory_logger.info.assert_not_called() mock_estimator_deploy.assert_called_once_with( instance_type="ml.p2.xlarge", diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index 3b6cbff2ad..560c29bfc3 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -43,6 +43,7 @@ class ModelTest(unittest.TestCase): mock_session_empty_config = MagicMock(sagemaker_config={}) + @mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER") @mock.patch("sagemaker.utils.sagemaker_timestamp") @mock.patch("sagemaker.jumpstart.model.is_valid_model_id") @mock.patch("sagemaker.jumpstart.factory.model.Session") @@ -58,6 +59,7 @@ def test_non_prepacked( mock_session: mock.Mock, mock_is_valid_model_id: mock.Mock, mock_sagemaker_timestamp: mock.Mock, + mock_jumpstart_model_factory_logger: mock.Mock, ): mock_model_deploy.return_value = default_predictor @@ -70,9 +72,14 @@ def test_non_prepacked( mock_session.return_value = sagemaker_session + mock_jumpstart_model_factory_logger.info.reset_mock() model = JumpStartModel( model_id=model_id, ) + mock_jumpstart_model_factory_logger.info.assert_called_once_with( + "No " "instance type selected for inference hosting endpoint. " "Defaulting to %s.", + "ml.p2.xlarge", + ) mock_model_init.assert_called_once_with( image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/" @@ -96,7 +103,9 @@ def test_non_prepacked( name="blahblahblah-7777", ) + mock_jumpstart_model_factory_logger.info.reset_mock() model.deploy() + mock_jumpstart_model_factory_logger.info.assert_not_called() mock_model_deploy.assert_called_once_with( initial_instance_count=1, From 5df5412d7aeaf6cdac94474863a5291d2f78dc18 Mon Sep 17 00:00:00 2001 From: Evan Kravitz Date: Mon, 13 Nov 2023 17:38:44 +0000 Subject: [PATCH 2/2] chore: improve improve kwarg name --- src/sagemaker/jumpstart/factory/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 8e891ec818..1971845ebe 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -169,7 +169,7 @@ def _add_vulnerable_and_deprecated_status_to_kwargs( def _add_instance_type_to_kwargs( - kwargs: JumpStartModelInitKwargs, disable_logging: bool = False + kwargs: JumpStartModelInitKwargs, disable_instance_type_logging: bool = False ) -> JumpStartModelInitKwargs: """Sets instance type based on default or override, returns full kwargs.""" @@ -186,7 +186,7 @@ def _add_instance_type_to_kwargs( training_instance_type=kwargs.training_instance_type, ) - if not disable_logging and orig_instance_type is None: + if not disable_instance_type_logging and orig_instance_type is None: JUMPSTART_LOGGER.info( "No instance type selected for inference hosting endpoint. Defaulting to %s.", kwargs.instance_type, @@ -687,7 +687,7 @@ def get_init_kwargs( model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs) model_init_kwargs = _add_instance_type_to_kwargs( - kwargs=model_init_kwargs, disable_logging=disable_instance_type_logging + kwargs=model_init_kwargs, disable_instance_type_logging=disable_instance_type_logging ) model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs)