diff --git a/src/sagemaker/automl/automl.py b/src/sagemaker/automl/automl.py index 247a5ae71e..ed9805f75e 100644 --- a/src/sagemaker/automl/automl.py +++ b/src/sagemaker/automl/automl.py @@ -216,8 +216,7 @@ def deploy( created from it. sagemaker_session (sagemaker.session.Session): A SageMaker Session object, used for SageMaker interactions (default: None). If not - specified, one is created using the default AWS configuration - chain. + specified, the one originally associated with the ``AutoML`` instance is used. name (str): The pipeline model name. If None, a default model name will be selected on each ``deploy``. endpoint_name (str): The name of the endpoint to create (default: @@ -248,6 +247,8 @@ def deploy( If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on the created endpoint name. Otherwise, ``None``. """ + sagemaker_session = sagemaker_session or self.sagemaker_session + if candidate is None: candidate_dict = self.best_candidate() candidate = CandidateEstimator(candidate_dict, sagemaker_session=sagemaker_session) diff --git a/tests/unit/sagemaker/automl/test_auto_ml.py b/tests/unit/sagemaker/automl/test_auto_ml.py index d4fd5fe236..c6b240e2b2 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml.py +++ b/tests/unit/sagemaker/automl/test_auto_ml.py @@ -477,16 +477,19 @@ def test_deploy(sagemaker_session, candidate_mock): ) -def test_deploy_optional_args(sagemaker_session, candidate_mock): +@patch("sagemaker.automl.automl.CandidateEstimator") +def test_deploy_optional_args(candidate_estimator, sagemaker_session, candidate_mock): + candidate_estimator.return_value = candidate_mock + auto_ml = AutoML( role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session ) - auto_ml.best_candidate = Mock(name="best_candidate", return_value=CANDIDATE_DICT) auto_ml._deploy_inference_pipeline = Mock("_deploy_inference_pipeline", return_value=None) auto_ml.deploy( initial_instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE, + candidate=CANDIDATE_DICT, sagemaker_session=sagemaker_session, name=JOB_NAME, endpoint_name=JOB_NAME, @@ -515,6 +518,8 @@ def test_deploy_optional_args(sagemaker_session, candidate_mock): predictor_cls=RealTimePredictor, ) + candidate_estimator.assert_called_with(CANDIDATE_DICT, sagemaker_session=sagemaker_session) + def test_candidate_estimator_get_steps(sagemaker_session): candidate_estimator = CandidateEstimator(CANDIDATE_DICT, sagemaker_session=sagemaker_session)