Skip to content

fix: set logs to False if wait is False #1585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 17, 2020
15 changes: 8 additions & 7 deletions src/sagemaker/automl/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""A class for SageMaker AutoML Jobs."""
from __future__ import absolute_import

import logging
from six import string_types

from sagemaker import Model, PipelineModel
Expand All @@ -21,6 +22,8 @@
from sagemaker.session import Session
from sagemaker.utils import name_from_base

logger = logging.getLogger("sagemaker")


class AutoML(object):
"""A class for creating and interacting with SageMaker AutoML jobs
Expand Down Expand Up @@ -78,16 +81,14 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
is stored. Or an AutoMLInput object. If a local path is provided, the dataset will
be uploaded to an S3 location.
wait (bool): Whether the call should wait until the job completes (default: True).
logs (bool): Whether to show the logs produced by the job.
Only meaningful when wait is True (default: True).
logs (bool): Whether to show the logs produced by the job. Only meaningful when wait
is True (default: True). if `wait` is False, `logs` will be set to False as well.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: need two backticks for formatting

job_name (str): Training job name. If not specified, the estimator generates
a default job name, based on the training image name and current timestamp.
"""
if logs and not wait:
raise ValueError(
"""Logs can only be shown if wait is set to True.
Please either set wait to True or set logs to False."""
)
if not wait and logs:
logs = False
logger.warning("logs will be set to False. logs is only meaningful when wait is True.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/logs will be set/Setting logs


# upload data for users if provided local path
# validations are done in _Job._format_inputs_to_input_config
Expand Down
5 changes: 3 additions & 2 deletions tests/unit/sagemaker/automl/test_auto_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def test_auto_ml_only_one_of_problem_type_and_job_objective_provided(sagemaker_s
)


def test_auto_ml_additional_optional_params(sagemaker_session):
def test_auto_ml_additional_optional_params(sagemaker_session, caplog):
auto_ml = AutoML(
role=ROLE,
target_attribute_name=TARGET_ATTRIBUTE_NAME,
Expand All @@ -314,7 +314,7 @@ def test_auto_ml_additional_optional_params(sagemaker_session):
tags=TAGS,
)
inputs = DEFAULT_S3_INPUT_DATA
auto_ml.fit(inputs, job_name=JOB_NAME)
auto_ml.fit(inputs, job_name=JOB_NAME, wait=False, logs=True)
sagemaker_session.auto_ml.assert_called_once()
_, args = sagemaker_session.auto_ml.call_args

Expand Down Expand Up @@ -348,6 +348,7 @@ def test_auto_ml_additional_optional_params(sagemaker_session):
"generate_candidate_definitions_only": GENERATE_CANDIDATE_DEFINITIONS_ONLY,
"tags": TAGS,
}
assert "logs will be set to False. logs is only meaningful when wait is True." in caplog.text
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in addition to testing for the log output, you could patch("sagemaker.automl.AutoMLJob.start_new") and then call .wait.assert_not_called() on the resulting mock

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a new test instead to test warning messages and assert_not_called()



@patch("time.strftime", return_value=TIMESTAMP)
Expand Down