Skip to content

Commit cae397d

Browse files
authored
make AutoML.deploy use self.sagemaker_session by default (#1311)
1 parent ad22096 commit cae397d

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

src/sagemaker/automl/automl.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,7 @@ def deploy(
216216
created from it.
217217
sagemaker_session (sagemaker.session.Session): A SageMaker Session
218218
object, used for SageMaker interactions (default: None). If not
219-
specified, one is created using the default AWS configuration
220-
chain.
219+
specified, the one originally associated with the ``AutoML`` instance is used.
221220
name (str): The pipeline model name. If None, a default model name will
222221
be selected on each ``deploy``.
223222
endpoint_name (str): The name of the endpoint to create (default:
@@ -248,6 +247,8 @@ def deploy(
248247
If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on
249248
the created endpoint name. Otherwise, ``None``.
250249
"""
250+
sagemaker_session = sagemaker_session or self.sagemaker_session
251+
251252
if candidate is None:
252253
candidate_dict = self.best_candidate()
253254
candidate = CandidateEstimator(candidate_dict, sagemaker_session=sagemaker_session)

tests/unit/sagemaker/automl/test_auto_ml.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -477,16 +477,19 @@ def test_deploy(sagemaker_session, candidate_mock):
477477
)
478478

479479

480-
def test_deploy_optional_args(sagemaker_session, candidate_mock):
480+
@patch("sagemaker.automl.automl.CandidateEstimator")
481+
def test_deploy_optional_args(candidate_estimator, sagemaker_session, candidate_mock):
482+
candidate_estimator.return_value = candidate_mock
483+
481484
auto_ml = AutoML(
482485
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
483486
)
484-
auto_ml.best_candidate = Mock(name="best_candidate", return_value=CANDIDATE_DICT)
485487
auto_ml._deploy_inference_pipeline = Mock("_deploy_inference_pipeline", return_value=None)
486488

487489
auto_ml.deploy(
488490
initial_instance_count=INSTANCE_COUNT,
489491
instance_type=INSTANCE_TYPE,
492+
candidate=CANDIDATE_DICT,
490493
sagemaker_session=sagemaker_session,
491494
name=JOB_NAME,
492495
endpoint_name=JOB_NAME,
@@ -515,6 +518,8 @@ def test_deploy_optional_args(sagemaker_session, candidate_mock):
515518
predictor_cls=RealTimePredictor,
516519
)
517520

521+
candidate_estimator.assert_called_with(CANDIDATE_DICT, sagemaker_session=sagemaker_session)
522+
518523

519524
def test_candidate_estimator_get_steps(sagemaker_session):
520525
candidate_estimator = CandidateEstimator(CANDIDATE_DICT, sagemaker_session=sagemaker_session)

0 commit comments

Comments
 (0)