
Commit 7cd8c1e

Merge branch 'master' into processor-docs
2 parents 2af4794 + af7f75a

6 files changed: +40 -12

src/sagemaker/automl/automl.py (+9 -7)
@@ -13,6 +13,7 @@
 """A class for SageMaker AutoML Jobs."""
 from __future__ import absolute_import

+import logging
 from six import string_types

 from sagemaker import Model, PipelineModel
@@ -21,6 +22,8 @@
 from sagemaker.session import Session
 from sagemaker.utils import name_from_base

+logger = logging.getLogger("sagemaker")
+

 class AutoML(object):
     """A class for creating and interacting with SageMaker AutoML jobs
@@ -78,16 +81,15 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
                 is stored. Or an AutoMLInput object. If a local path is provided, the dataset will
                 be uploaded to an S3 location.
             wait (bool): Whether the call should wait until the job completes (default: True).
-            logs (bool): Whether to show the logs produced by the job.
-                Only meaningful when wait is True (default: True).
+            logs (bool): Whether to show the logs produced by the job. Only meaningful when wait
+                is True (default: True). If ``wait`` is False, ``logs`` will be set to False as
+                well.
             job_name (str): Training job name. If not specified, the estimator generates
                 a default job name, based on the training image name and current timestamp.
         """
-        if logs and not wait:
-            raise ValueError(
-                """Logs can only be shown if wait is set to True.
-                Please either set wait to True or set logs to False."""
-            )
+        if not wait and logs:
+            logs = False
+            logger.warning("Setting logs to False. logs is only meaningful when wait is True.")

         # upload data for users if provided local path
         # validations are done in _Job._format_inputs_to_input_config
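In effect, ``AutoML.fit`` now downgrades ``logs`` instead of raising when ``wait=False``. A minimal sketch of the new behavior (the role, target column, and S3 path are illustrative placeholders, and configured AWS credentials are assumed):

    from sagemaker.automl.automl import AutoML

    auto_ml = AutoML(role="my-sagemaker-role", target_attribute_name="label")

    # Before this commit, wait=False with logs=True raised:
    #   "Logs can only be shown if wait is set to True."
    # Now fit() proceeds: logs is forced to False and a warning is emitted
    # on the "sagemaker" logger.
    auto_ml.fit("s3://my-bucket/train.csv", wait=False, logs=True)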

src/sagemaker/fw_utils.py (+5 -5)
@@ -49,8 +49,8 @@
     "Please set the argument \"py_version='py3'\" to use the Python 3 {framework} image."
 )
 PARAMETER_SERVER_MULTI_GPU_WARNING = (
-    "You have selected a multi-GPU training instance type. "
-    "You have also enabled parameter server for distributed training. "
+    "If you have selected a multi-GPU training instance type, "
+    "and have also enabled parameter server for distributed training. "
     "Distributed training with the default parameter server configuration will not "
     "fully leverage all GPU cores; the parameter server will be configured to run "
     "only one worker per host regardless of the number of GPUs."
@@ -625,9 +625,9 @@ def warn_if_parameter_server_with_multi_gpu(training_instance_type, distribution
         return

     is_multi_gpu_instance = (
-        training_instance_type.split(".")[1].startswith("p")
-        and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES
-    )
+        training_instance_type == "local_gpu"
+        or training_instance_type.split(".")[1].startswith("p")
+    ) and training_instance_type not in SINGLE_GPU_INSTANCE_TYPES

     ps_enabled = "parameter_server" in distributions and distributions["parameter_server"].get(
         "enabled", False

src/sagemaker/workflow/airflow.py (+3 -0)
@@ -181,6 +181,9 @@ def training_base_config(estimator, inputs=None, job_name=None, mini_batch_size=
     if job_config["vpc_config"] is not None:
         train_config["VpcConfig"] = job_config["vpc_config"]

+    if estimator.train_use_spot_instances:
+        train_config["EnableManagedSpotTraining"] = True
+
     if estimator.hyperparameters() is not None:
         hyperparameters = {str(k): str(v) for (k, v) in estimator.hyperparameters().items()}
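A hedged sketch of how the new flag surfaces in an Airflow training config (image name, role, and S3 paths are placeholders; ``train_max_wait`` accompanies spot training since managed spot jobs need a wait ceiling, and configured AWS credentials are assumed):

    from sagemaker.estimator import Estimator
    from sagemaker.workflow import airflow

    estimator = Estimator(
        image_name="my-training-image",
        role="my-sagemaker-role",
        train_instance_count=1,
        train_instance_type="ml.m5.xlarge",
        train_use_spot_instances=True,
        train_max_wait=3600,
        output_path="s3://my-bucket/output",
    )
    config = airflow.training_base_config(estimator, inputs="s3://my-bucket/train")
    assert config["EnableManagedSpotTraining"] is True

Any estimator whose ``train_use_spot_instances`` is truthy should now carry ``EnableManagedSpotTraining`` in the generated config, which the updated ``test_byo_training_config_all_args`` below verifies.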

tests/unit/sagemaker/automl/test_auto_ml.py (+11 -0)
@@ -294,6 +294,17 @@ def test_auto_ml_only_one_of_problem_type_and_job_objective_provided(sagemaker_s
     )


+@patch("sagemaker.automl.automl.AutoMLJob.start_new")
+def test_auto_ml_fit_set_logs_to_false(start_new, sagemaker_session, caplog):
+    auto_ml = AutoML(
+        role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
+    )
+    inputs = DEFAULT_S3_INPUT_DATA
+    auto_ml.fit(inputs, job_name=JOB_NAME, wait=False, logs=True)
+    start_new.wait.assert_not_called()
+    assert "Setting logs to False. logs is only meaningful when wait is True." in caplog.text
+
+
 def test_auto_ml_additional_optional_params(sagemaker_session):
     auto_ml = AutoML(
         role=ROLE,

tests/unit/test_airflow.py (+2 -0)
@@ -105,6 +105,7 @@ def test_byo_training_config_all_args(sagemaker_session):
         model_uri="{{ model_uri }}",
         model_channel_name="{{ model_chanel }}",
         sagemaker_session=sagemaker_session,
+        train_use_spot_instances=True,
     )

     byo.set_hyperparameters(epochs=32, feature_dim=1024, mini_batch_size=256)
@@ -155,6 +156,7 @@ def test_byo_training_config_all_args(sagemaker_session):
             "Subnets": ["{{ subnet }}"],
             "SecurityGroupIds": ["{{ security_group_ids }}"],
         },
+        "EnableManagedSpotTraining": True,
         "HyperParameters": {"epochs": "32", "feature_dim": "1024", "mini_batch_size": "256"},
         "Tags": [{"{{ key }}": "{{ value }}"}],
     }

tests/unit/test_fw_utils.py (+10 -0)
@@ -1272,3 +1272,13 @@ def test_warn_if_parameter_server_with_multi_gpu(caplog):
         training_instance_type=train_instance_type, distributions=distributions
     )
     assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text
+
+
+def test_warn_if_parameter_server_with_local_multi_gpu(caplog):
+    train_instance_type = "local_gpu"
+    distributions = {"parameter_server": {"enabled": True}}
+
+    fw_utils.warn_if_parameter_server_with_multi_gpu(
+        training_instance_type=train_instance_type, distributions=distributions
+    )
+    assert fw_utils.PARAMETER_SERVER_MULTI_GPU_WARNING in caplog.text
