Skip to content

Commit ed02f43

Browse files
chuyang-dengknakad
authored andcommitted
fix: correct AutoML imports and expose current_job_name (#300)
1 parent 87bced5 commit ed02f43

File tree

5 files changed

+23
-22
lines changed

5 files changed

+23
-22
lines changed

src/sagemaker/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,7 @@
5757
from sagemaker.session import s3_input # noqa: F401
5858
from sagemaker.session import get_execution_role # noqa: F401
5959

60+
from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput # noqa: F401
61+
from sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep # noqa: F401
62+
6063
__version__ = pkg_resources.require("sagemaker")[0].version

src/sagemaker/automl/automl.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
self.generate_candidate_definitions_only = generate_candidate_definitions_only
6464
self.tags = tags
6565

66-
self._current_job_name = None
66+
self.current_job_name = None
6767
self._auto_ml_job_desc = None
6868
self._best_candidate = None
6969
self.sagemaker_session = sagemaker_session or Session()
@@ -111,7 +111,7 @@ def describe_auto_ml_job(self, job_name=None):
111111
dict: A dictionary response with the AutoML Job description.
112112
"""
113113
if job_name is None:
114-
job_name = self._current_job_name
114+
job_name = self.current_job_name
115115
self._auto_ml_job_desc = self.sagemaker_session.describe_auto_ml_job(job_name)
116116
return self._auto_ml_job_desc
117117

@@ -128,7 +128,7 @@ def best_candidate(self, job_name=None):
128128
return self._best_candidate
129129

130130
if job_name is None:
131-
job_name = self._current_job_name
131+
job_name = self.current_job_name
132132
if self._auto_ml_job_desc is None:
133133
self._auto_ml_job_desc = self.sagemaker_session.describe_auto_ml_job(job_name)
134134
elif self._auto_ml_job_desc["AutoMLJobName"] != job_name:
@@ -168,7 +168,7 @@ def list_candidates(
168168
list: A list of dictionaries with candidates information
169169
"""
170170
if job_name is None:
171-
job_name = self._current_job_name
171+
job_name = self.current_job_name
172172

173173
list_candidates_args = {"job_name": job_name}
174174

@@ -249,6 +249,7 @@ def deploy(
249249
candidate = CandidateEstimator(candidate, sagemaker_session=sagemaker_session)
250250

251251
inference_containers = candidate.containers
252+
endpoint_name = endpoint_name or self.current_job_name
252253

253254
return self._deploy_inference_pipeline(
254255
inference_containers,
@@ -373,14 +374,14 @@ def _prepare_for_auto_ml_job(self, job_name=None):
373374
created from base_job_name or "sagemaker-auto-ml".
374375
"""
375376
if job_name is not None:
376-
self._current_job_name = job_name
377+
self.current_job_name = job_name
377378
else:
378379
if self.base_job_name:
379380
base_name = self.base_job_name
380381
else:
381382
base_name = "sagemaker-auto-ml"
382383
# CreateAutoMLJob API validates that member length less than or equal to 32
383-
self._current_job_name = name_from_base(base_name, max_length=32)
384+
self.current_job_name = name_from_base(base_name, max_length=32)
384385

385386
if self.output_path is None:
386387
self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket())
@@ -426,6 +427,7 @@ class AutoMLJob(_Job):
426427

427428
def __init__(self, sagemaker_session, job_name, inputs):
428429
self.inputs = inputs
430+
self.job_name = job_name
429431
super(AutoMLJob, self).__init__(sagemaker_session=sagemaker_session, job_name=job_name)
430432

431433
@classmethod
@@ -444,13 +446,13 @@ def start_new(cls, auto_ml, inputs):
444446
"""
445447
config = cls._load_config(inputs, auto_ml)
446448
auto_ml_args = config.copy()
447-
auto_ml_args["job_name"] = auto_ml._current_job_name
449+
auto_ml_args["job_name"] = auto_ml.current_job_name
448450
auto_ml_args["problem_type"] = auto_ml.problem_type
449451
auto_ml_args["job_objective"] = auto_ml.job_objective
450452
auto_ml_args["tags"] = auto_ml.tags
451453

452454
auto_ml.sagemaker_session.auto_ml(**auto_ml_args)
453-
return cls(auto_ml.sagemaker_session, auto_ml._current_job_name, inputs)
455+
return cls(auto_ml.sagemaker_session, auto_ml.current_job_name, inputs)
454456

455457
@classmethod
456458
def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True):

src/sagemaker/session.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1326,12 +1326,12 @@ def wait_for_auto_ml_job(self, job, poll=5):
13261326
"""Wait for an Amazon SageMaker AutoML job to complete.
13271327
13281328
Args:
1329-
job (str): Name of the transform job to wait for.
1329+
job (str): Name of the auto ml job to wait for.
13301330
poll (int): Polling interval in seconds (default: 5).
13311331
Returns:
1332-
(dict): Return value from the ``DescribeTransformJob`` API.
1332+
(dict): Return value from the ``DescribeAutoMLJob`` API.
13331333
Raises:
1334-
exceptions.UnexpectedStatusException: If the transform job fails.
1334+
exceptions.UnexpectedStatusException: If the auto ml job fails.
13351335
"""
13361336
desc = _wait_until(lambda: _auto_ml_job_status(self.sagemaker_client, job), poll)
13371337
self._check_job_status(job, desc, "AutoMLJobStatus")

tests/integ/test_auto_ml.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,14 @@
1616
import time
1717

1818
import pytest
19+
from sagemaker import AutoML, CandidateEstimator, AutoMLInput
1920

20-
from sagemaker.automl.automl import AutoML, AutoMLInput
21-
from sagemaker.automl.candidate_estimator import CandidateEstimator
2221
from sagemaker.exceptions import UnexpectedStatusException
2322
from sagemaker.utils import unique_name_from_base
2423
from tests.integ import DATA_DIR, AUTO_ML_DEFAULT_TIMEMOUT_MINUTES
2524
from tests.integ.timeout import timeout
2625

2726
DEV_ACCOUNT = 142577830533
28-
# ROLE = "arn:aws:iam::142577830533:role/SageMakerRole"
2927
ROLE = "SageMakerRole"
3028
PREFIX = "sagemaker/beta-automl-xgboost"
3129
HOSTING_INSTANCE_TYPE = "ml.c4.xlarge"

tests/unit/sagemaker/automl/test_auto_ml.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414

1515
import pytest
1616
from mock import Mock, patch
17-
18-
from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput
19-
from sagemaker.automl.candidate_estimator import CandidateEstimator
17+
from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator
2018

2119
MODEL_DATA = "s3://bucket/model.tar.gz"
2220
MODEL_IMAGE = "mi"
@@ -363,10 +361,10 @@ def test_list_candidates_default(sagemaker_session):
363361
auto_ml = AutoML(
364362
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
365363
)
366-
auto_ml._current_job_name = "current_job_name"
364+
auto_ml.current_job_name = "current_job_name"
367365
auto_ml.list_candidates()
368366
sagemaker_session.list_candidates.assert_called_once()
369-
sagemaker_session.list_candidates.assert_called_with(job_name=auto_ml._current_job_name)
367+
sagemaker_session.list_candidates.assert_called_with(job_name=auto_ml.current_job_name)
370368

371369

372370
def test_list_candidates_with_optional_args(sagemaker_session):
@@ -409,7 +407,7 @@ def test_best_candidate_default_job_name(sagemaker_session):
409407
auto_ml = AutoML(
410408
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
411409
)
412-
auto_ml._current_job_name = JOB_NAME
410+
auto_ml.current_job_name = JOB_NAME
413411
auto_ml._auto_ml_job_desc = AUTO_ML_DESC
414412
best_candidate = auto_ml.best_candidate()
415413
sagemaker_session.describe_auto_ml_job.assert_not_called()
@@ -420,7 +418,7 @@ def test_best_candidate_job_no_desc(sagemaker_session):
420418
auto_ml = AutoML(
421419
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
422420
)
423-
auto_ml._current_job_name = JOB_NAME
421+
auto_ml.current_job_name = JOB_NAME
424422
best_candidate = auto_ml.best_candidate()
425423
sagemaker_session.describe_auto_ml_job.assert_called_once()
426424
sagemaker_session.describe_auto_ml_job.assert_called_with(JOB_NAME)
@@ -441,7 +439,7 @@ def test_best_candidate_job_name_not_match(sagemaker_session):
441439
auto_ml = AutoML(
442440
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
443441
)
444-
auto_ml._current_job_name = JOB_NAME
442+
auto_ml.current_job_name = JOB_NAME
445443
auto_ml._auto_ml_job_desc = AUTO_ML_DESC
446444
best_candidate = auto_ml.best_candidate(job_name=JOB_NAME_2)
447445
sagemaker_session.describe_auto_ml_job.assert_called_once()

0 commit comments

Comments
 (0)