Skip to content

Commit 789bed9

Browse files
chuyang-deng authored and knakad committed
change: AutoML improvements (aws#281)
* change: AutoML improvements * AutoML improvements * address comments
1 parent 516b9c5 commit 789bed9

File tree

7 files changed

+137
-33
lines changed

7 files changed

+137
-33
lines changed

buildspec.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ phases:
1919
- start_time=`date +%s`
2020
- |
2121
if has-matching-changes "tests/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"; then
22-
tox -e py36 -- tests/integ -m "not local_mode" -n 48 --reruns 3 --reruns-delay 5 --durations 50 --boto-config '{"region_name": "us-west-2"}'
22+
tox -e py36 -- tests/integ -m "not local_mode" -n 48 --reruns 3 --reruns-delay 5 --durations 50 --boto-config '{"region_name": "us-east-2"}'
2323
fi
2424
- ./ci-scripts/displaytime.sh 'py36 tests/integ' $start_time
2525

src/sagemaker/automl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of

src/sagemaker/automl/automl.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,9 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
7474
"""Create an AutoML Job with the input dataset.
7575
7676
Args:
77-
inputs (str or list[str]): Local path or S3 Uri where the training data is stored. If a
78-
local path is provided, the dataset will be uploaded to an S3 location.
77+
inputs (str or list[str] or AutoMLInput): Local path or S3 Uri where the training data
78+
is stored. Or an AutoMLInput object. If a local path is provided, the dataset will
79+
be uploaded to an S3 location.
7980
wait (bool): Whether the call should wait until the job completes (default: True).
8081
logs (bool): Whether to show the logs produced by the job.
8182
Only meaningful when wait is True (default: True).
@@ -95,7 +96,7 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
9596
inputs = self.sagemaker_session.upload_data(inputs, key_prefix="auto-ml-input-data")
9697
self._prepare_for_auto_ml_job(job_name=job_name)
9798

98-
self.latest_auto_ml_job = _AutoMLJob.start_new(self, inputs) # pylint: disable=W0201
99+
self.latest_auto_ml_job = AutoMLJob.start_new(self, inputs) # pylint: disable=W0201
99100
if wait:
100101
self.latest_auto_ml_job.wait(logs=logs)
101102

@@ -385,9 +386,48 @@ def _prepare_for_auto_ml_job(self, job_name=None):
385386
self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket())
386387

387388

388-
class _AutoMLJob(_Job):
389+
class AutoMLInput(object):
390+
"""Accepts parameters that specify an S3 input for an auto ml job and provides
391+
a method to turn those parameters into a dictionary."""
392+
393+
def __init__(self, inputs, target_attribute_name, compression=None):
394+
"""Convert an S3 Uri or a list of S3 Uri to an AutoMLInput object.
395+
396+
:param inputs (str, list[str]): a string or a list of string that points to (a)
397+
S3 location(s) where input data is stored.
398+
:param target_attribute_name (str): the target attribute name for regression
399+
or classification.
400+
:param compression (str): if training data is compressed, the compression type.
401+
The default value is None.
402+
"""
403+
self.inputs = inputs
404+
self.target_attribute_name = target_attribute_name
405+
self.compression = compression
406+
407+
def to_request_dict(self):
408+
"""Generates a request dictionary using the parameters provided to the class."""
409+
# Create the request dictionary.
410+
auto_ml_input = []
411+
if isinstance(self.inputs, string_types):
412+
self.inputs = [self.inputs]
413+
for entry in self.inputs:
414+
input_entry = {
415+
"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": entry}},
416+
"TargetAttributeName": self.target_attribute_name,
417+
}
418+
if self.compression is not None:
419+
input_entry["CompressionType"] = self.compression
420+
auto_ml_input.append(input_entry)
421+
return auto_ml_input
422+
423+
424+
class AutoMLJob(_Job):
389425
"""A class for interacting with CreateAutoMLJob API."""
390426

427+
def __init__(self, sagemaker_session, job_name, inputs):
428+
self.inputs = inputs
429+
super(AutoMLJob, self).__init__(sagemaker_session=sagemaker_session, job_name=job_name)
430+
391431
@classmethod
392432
def start_new(cls, auto_ml, inputs):
393433
"""Create a new Amazon SageMaker AutoML job from auto_ml.
@@ -399,7 +439,7 @@ def start_new(cls, auto_ml, inputs):
399439
:meth:`~sagemaker.automl.AutoML.fit`.
400440
401441
Returns:
402-
sagemaker.automl._AutoMLJob: Constructed object that captures
442+
sagemaker.automl.AutoMLJob: Constructed object that captures
403443
all information about the started AutoML job.
404444
"""
405445
config = cls._load_config(inputs, auto_ml)
@@ -410,7 +450,7 @@ def start_new(cls, auto_ml, inputs):
410450
auto_ml_args["tags"] = auto_ml.tags
411451

412452
auto_ml.sagemaker_session.auto_ml(**auto_ml_args)
413-
return cls(auto_ml.sagemaker_session, auto_ml._current_job_name)
453+
return cls(auto_ml.sagemaker_session, auto_ml._current_job_name, inputs)
414454

415455
@classmethod
416456
def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True):
@@ -432,9 +472,12 @@ def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True):
432472
# InputDataConfig
433473
# OutputConfig
434474

435-
input_config = cls._format_inputs_to_input_config(
436-
inputs, validate_uri, auto_ml.compression_type, auto_ml.target_attribute_name
437-
)
475+
if isinstance(inputs, AutoMLInput):
476+
input_config = inputs.to_request_dict()
477+
else:
478+
input_config = cls._format_inputs_to_input_config(
479+
inputs, validate_uri, auto_ml.compression_type, auto_ml.target_attribute_name
480+
)
438481
output_config = _Job._prepare_output_config(auto_ml.output_path, auto_ml.output_kms_key)
439482

440483
role = auto_ml.sagemaker_session.expand_role(auto_ml.role) if expand_role else auto_ml.role
@@ -486,7 +529,9 @@ def _format_inputs_to_input_config(
486529
return None
487530

488531
channels = []
489-
if isinstance(inputs, string_types):
532+
if isinstance(inputs, AutoMLInput):
533+
channels.append(inputs.to_request_dict())
534+
elif isinstance(inputs, string_types):
490535
channel = _Job._format_string_uri_input(
491536
inputs,
492537
validate_uri,
@@ -540,6 +585,10 @@ def _prepare_auto_ml_stop_condition(
540585

541586
return stopping_condition
542587

588+
def describe(self):
589+
"""Prints out a response from the DescribeAutoMLJob API call."""
590+
return self.sagemaker_session.describe_auto_ml_job(self.job_name)
591+
543592
def wait(self, logs=True):
544593
"""Wait for the AutoML job to finish.
545594
Args:

src/sagemaker/automl/candidate_estimator.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,14 @@ def get_steps(self):
5656
)
5757

5858
inputs = training_job["InputDataConfig"]
59-
candidate_step = {
60-
"name": step_name,
61-
"inputs": inputs,
62-
"type": step_type,
63-
"desc": training_job,
64-
}
59+
candidate_step = CandidateStep(step_name, inputs, step_type, training_job)
6560
candidate_steps.append(candidate_step)
6661
elif step_type == "TransformJob":
6762
transform_job = self.sagemaker_session.sagemaker_client.describe_transform_job(
6863
TransformJobName=step_name
6964
)
7065
inputs = transform_job["TransformInput"]
71-
candidate_step = {
72-
"name": step_name,
73-
"inputs": inputs,
74-
"type": step_type,
75-
"desc": transform_job,
76-
}
66+
candidate_step = CandidateStep(step_name, inputs, step_type, transform_job)
7767
candidate_steps.append(candidate_step)
7868
return candidate_steps
7969

@@ -313,3 +303,33 @@ def _process_steps(self, steps):
313303
step_type = step["CandidateStepType"].split("::")[2]
314304
processed_steps.append({"name": step_name, "type": step_type})
315305
return processed_steps
306+
307+
308+
class CandidateStep(object):
309+
"""A class that maintains an AutoML Candidate step's name, inputs, type, and description."""
310+
311+
def __init__(self, name, inputs, step_type, description):
312+
self._name = name
313+
self._inputs = inputs
314+
self._type = step_type
315+
self._description = description
316+
317+
@property
318+
def name(self):
319+
"""Name of the candidate step -> (str)"""
320+
return self._name
321+
322+
@property
323+
def inputs(self):
324+
"""Inputs of the candidate step -> (dict)"""
325+
return self._inputs
326+
327+
@property
328+
def type(self):
329+
"""Type of the candidate step, Training or Transform -> (str)"""
330+
return self._type
331+
332+
@property
333+
def description(self):
334+
"""Description of candidate step job -> (dict)"""
335+
return self._description

src/sagemaker/job.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,14 @@ def _format_string_uri_input(
195195
target_attribute_name:
196196
"""
197197
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
198-
return s3_input(
198+
s3_input_result = s3_input(
199199
uri_input,
200200
content_type=content_type,
201201
input_mode=input_mode,
202202
compression=compression,
203203
target_attribute_name=target_attribute_name,
204204
)
205+
return s3_input_result
205206
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
206207
return file_input(uri_input)
207208
if isinstance(uri_input, str) and validate_uri:
@@ -210,13 +211,14 @@ def _format_string_uri_input(
210211
'"file://"'.format(uri_input)
211212
)
212213
if isinstance(uri_input, str):
213-
return s3_input(
214+
s3_input_result = s3_input(
214215
uri_input,
215216
content_type=content_type,
216217
input_mode=input_mode,
217218
compression=compression,
218219
target_attribute_name=target_attribute_name,
219220
)
221+
return s3_input_result
220222
if isinstance(uri_input, (s3_input, file_input, FileSystemInput)):
221223
return uri_input
222224

tests/integ/test_auto_ml.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import pytest
1919

20-
from sagemaker.automl.automl import AutoML
20+
from sagemaker.automl.automl import AutoML, AutoMLInput
2121
from sagemaker.automl.candidate_estimator import CandidateEstimator
2222
from sagemaker.exceptions import UnexpectedStatusException
2323
from sagemaker.utils import unique_name_from_base
@@ -47,7 +47,7 @@
4747
"DataSource": {
4848
"S3DataSource": {
4949
"S3DataType": "S3Prefix",
50-
"S3Uri": "s3://sagemaker-us-west-2-{}/{}/input/iris_training.csv".format(
50+
"S3Uri": "s3://sagemaker-us-east-2-{}/{}/input/iris_training.csv".format(
5151
DEV_ACCOUNT, PREFIX
5252
),
5353
}
@@ -60,7 +60,7 @@
6060
"SecurityConfig": {"EnableInterContainerTrafficEncryption": False},
6161
}
6262
EXPECTED_DEFAULT_OUTPUT_CONFIG = {
63-
"S3OutputPath": "s3://sagemaker-us-west-2-{}/".format(DEV_ACCOUNT)
63+
"S3OutputPath": "s3://sagemaker-us-east-2-{}/".format(DEV_ACCOUNT)
6464
}
6565

6666

@@ -90,8 +90,21 @@ def test_auto_ml_fit_local_input(sagemaker_session):
9090
auto_ml.fit(inputs)
9191

9292

93+
def test_auto_ml_input_object_fit(sagemaker_session):
94+
auto_ml = AutoML(
95+
role=ROLE,
96+
target_attribute_name=TARGET_ATTRIBUTE_NAME,
97+
sagemaker_session=sagemaker_session,
98+
max_candidates=1,
99+
)
100+
s3_input = sagemaker_session.upload_data(path=TRAINING_DATA, key_prefix=PREFIX + "/input")
101+
inputs = AutoMLInput(inputs=s3_input, target_attribute_name=TARGET_ATTRIBUTE_NAME)
102+
with timeout(minutes=AUTO_ML_DEFAULT_TIMEMOUT_MINUTES):
103+
auto_ml.fit(inputs)
104+
105+
93106
def test_auto_ml_fit_optional_args(sagemaker_session):
94-
output_path = "s3://sagemaker-us-west-2-{}/{}".format(DEV_ACCOUNT, "specified_ouput_path")
107+
output_path = "s3://sagemaker-us-east-2-{}/{}".format(DEV_ACCOUNT, "specified_ouput_path")
95108
problem_type = "MulticlassClassification"
96109
job_objective = {"MetricName": "Accuracy"}
97110
auto_ml = AutoML(

tests/unit/sagemaker/automl/test_auto_ml.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import pytest
1616
from mock import Mock, patch
1717

18-
from sagemaker.automl.automl import AutoML, _AutoMLJob
18+
from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput
1919
from sagemaker.automl.candidate_estimator import CandidateEstimator
2020

2121
MODEL_DATA = "s3://bucket/model.tar.gz"
@@ -189,7 +189,7 @@ def test_auto_ml_default_channel_name(sagemaker_session):
189189
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
190190
)
191191
inputs = DEFAULT_S3_INPUT_DATA
192-
_AutoMLJob.start_new(auto_ml, inputs)
192+
AutoMLJob.start_new(auto_ml, inputs)
193193
sagemaker_session.auto_ml.assert_called_once()
194194
_, args = sagemaker_session.auto_ml.call_args
195195
assert args["input_config"] == [
@@ -210,7 +210,7 @@ def test_auto_ml_invalid_input_data_format(sagemaker_session):
210210

211211
expected_error_msg = "Cannot format input {}. Expecting one of str or list of strings."
212212
with pytest.raises(ValueError, message=expected_error_msg.format(inputs)):
213-
_AutoMLJob.start_new(auto_ml, inputs)
213+
AutoMLJob.start_new(auto_ml, inputs)
214214
sagemaker_session.auto_ml.assert_not_called()
215215

216216

@@ -330,6 +330,26 @@ def test_auto_ml_local_input(sagemaker_session):
330330
assert args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] == DEFAULT_S3_INPUT_DATA
331331

332332

333+
def test_auto_ml_input(sagemaker_session):
334+
inputs = AutoMLInput(
335+
inputs=DEFAULT_S3_INPUT_DATA, target_attribute_name="target", compression="Gzip"
336+
)
337+
auto_ml = AutoML(
338+
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
339+
)
340+
auto_ml.fit(inputs)
341+
_, args = sagemaker_session.auto_ml.call_args
342+
assert args["input_config"] == [
343+
{
344+
"CompressionType": "Gzip",
345+
"DataSource": {
346+
"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA}
347+
},
348+
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
349+
}
350+
]
351+
352+
333353
def test_describe_auto_ml_job(sagemaker_session):
334354
auto_ml = AutoML(
335355
role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session

0 commit comments

Comments
 (0)