Skip to content

Commit 71f9bb7

Browse files
chuyang-dengknakad
authored andcommitted
feature: introduce SageMaker AutoML (aws#257)
Introduce AutoML feature PySDK support.
1 parent 6273bdd commit 71f9bb7

File tree

13 files changed

+3141
-12
lines changed

13 files changed

+3141
-12
lines changed

buildspec.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ phases:
1919
- start_time=`date +%s`
2020
- |
2121
if has-matching-changes "tests/" "src/*.py" "setup.py" "setup.cfg" "buildspec.yml"; then
22-
tox -e py36 -- tests/integ -m "not local_mode" -n 48 --reruns 3 --reruns-delay 5 --durations 50 --boto-config '{"region_name": "us-east-2"}'
22+
tox -e py36 -- tests/integ -m "not local_mode" -n 48 --reruns 3 --reruns-delay 5 --durations 50 --boto-config '{"region_name": "us-west-2"}'
2323
fi
2424
- ./ci-scripts/displaytime.sh 'py36 tests/integ' $start_time
2525

src/sagemaker/automl/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.

src/sagemaker/automl/automl.py

Lines changed: 551 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A class for AutoML Job's Candidate."""
14+
from __future__ import absolute_import
15+
16+
from six import string_types
17+
18+
from sagemaker import Session
19+
from sagemaker.job import _Job
20+
from sagemaker.utils import name_from_base
21+
22+
23+
class CandidateEstimator(object):
24+
"""A class for SageMaker AutoML Job Candidate"""
25+
26+
def __init__(self, candidate, sagemaker_session=None):
27+
"""Constructor of CandidateEstimator.
28+
29+
Args:
30+
candidate (dict): a dictionary of candidate returned by AutoML.list_candidates()
31+
or AutoML.best_candidate().
32+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
33+
object, used for SageMaker interactions (default: None). If not
34+
specified, one is created using the default AWS configuration
35+
chain.
36+
"""
37+
self.name = candidate["CandidateName"]
38+
self.containers = candidate["InferenceContainers"]
39+
self.steps = self._process_steps(candidate["CandidateSteps"])
40+
self.sagemaker_session = sagemaker_session or Session()
41+
42+
def get_steps(self):
43+
"""Get the step job of a candidate so that users can construct estimators/transformers
44+
45+
Returns:
46+
list: a list of dictionaries that provide information about each step job's name,
47+
type, inputs and description
48+
"""
49+
candidate_steps = []
50+
for step in self.steps:
51+
step_type = step["type"]
52+
step_name = step["name"]
53+
if step_type == "TrainingJob":
54+
training_job = self.sagemaker_session.sagemaker_client.describe_training_job(
55+
TrainingJobName=step_name
56+
)
57+
58+
inputs = training_job["InputDataConfig"]
59+
candidate_step = {
60+
"name": step_name,
61+
"inputs": inputs,
62+
"type": step_type,
63+
"desc": training_job,
64+
}
65+
candidate_steps.append(candidate_step)
66+
elif step_type == "TransformJob":
67+
transform_job = self.sagemaker_session.sagemaker_client.describe_transform_job(
68+
TransformJobName=step_name
69+
)
70+
inputs = transform_job["TransformInput"]
71+
candidate_step = {
72+
"name": step_name,
73+
"inputs": inputs,
74+
"type": step_type,
75+
"desc": transform_job,
76+
}
77+
candidate_steps.append(candidate_step)
78+
return candidate_steps
79+
80+
def fit(
81+
self,
82+
inputs,
83+
candidate_name=None,
84+
volume_kms_key=None,
85+
encrypt_inter_container_traffic=False,
86+
vpc_config=None,
87+
wait=True,
88+
logs=True,
89+
):
90+
"""Rerun a candidate's step jobs with new input datasets or security config.
91+
92+
Args:
93+
inputs (str or list[str]): Local path or S3 Uri where the training data is stored. If a
94+
local path is provided, the dataset will be uploaded to an S3 location.
95+
candidate_name (str): name of the candidate to be rerun, if None, candidate's original
96+
name will be used.
97+
volume_kms_key (str): The KMS key id to encrypt data on the storage volume attached to
98+
the ML compute instance(s).
99+
encrypt_inter_container_traffic (bool): To encrypt all communications between ML compute
100+
instances in distributed training. Default: False.
101+
vpc_config (dict): Specifies a VPC that jobs and hosted models have access to.
102+
Control access to and from training and model containers by configuring the VPC
103+
wait (bool): Whether the call should wait until all jobs completes (default: True).
104+
logs (bool): Whether to show the logs produced by the job.
105+
Only meaningful when wait is True (default: True).
106+
"""
107+
if logs and not wait:
108+
raise ValueError(
109+
"""Logs can only be shown if wait is set to True.
110+
Please either set wait to True or set logs to False."""
111+
)
112+
113+
self.name = candidate_name or self.name
114+
running_jobs = {}
115+
116+
# convert inputs to s3_input format
117+
if isinstance(inputs, string_types):
118+
if not inputs.startswith("s3://"):
119+
inputs = self.sagemaker_session.upload_data(inputs, key_prefix="auto-ml-input-data")
120+
121+
for step in self.steps:
122+
step_type = step["type"]
123+
step_name = step["name"]
124+
if step_type == "TrainingJob":
125+
# prepare inputs
126+
input_dict = {}
127+
if isinstance(inputs, string_types):
128+
input_dict["train"] = _Job._format_string_uri_input(inputs)
129+
else:
130+
msg = "Cannot format input {}. Expecting a string."
131+
raise ValueError(msg.format(inputs))
132+
133+
channels = [
134+
_Job._convert_input_to_channel(name, input)
135+
for name, input in input_dict.items()
136+
]
137+
138+
desc = self.sagemaker_session.sagemaker_client.describe_training_job(
139+
TrainingJobName=step_name
140+
)
141+
base_name = "sagemaker-automl-training-rerun"
142+
step_name = name_from_base(base_name)
143+
step["name"] = step_name
144+
train_args = self._get_train_args(
145+
desc,
146+
channels,
147+
step_name,
148+
volume_kms_key,
149+
encrypt_inter_container_traffic,
150+
vpc_config,
151+
)
152+
self.sagemaker_session.train(**train_args)
153+
running_jobs[step_name] = True
154+
155+
elif step_type == "TransformJob":
156+
# prepare inputs
157+
if not isinstance(inputs, string_types) or not inputs.startswith("s3://"):
158+
msg = "Cannot format input {}. Expecting a string starts with file:// or s3://"
159+
raise ValueError(msg.format(inputs))
160+
161+
desc = self.sagemaker_session.sagemaker_client.describe_transform_job(
162+
TransformJobName=step_name
163+
)
164+
base_name = "sagemaker-automl-transform-rerun"
165+
step_name = name_from_base(base_name)
166+
step["name"] = step_name
167+
transform_args = self._get_transform_args(desc, inputs, step_name, volume_kms_key)
168+
self.sagemaker_session.transform(**transform_args)
169+
running_jobs[step_name] = True
170+
171+
if wait:
172+
while True:
173+
for step in self.steps:
174+
status = None
175+
step_type = step["type"]
176+
step_name = step["name"]
177+
if step_type == "TrainingJob":
178+
status = self.sagemaker_session.sagemaker_client.describe_training_job(
179+
TrainingJobName=step_name
180+
)["TrainingJobStatus"]
181+
elif step_type == "TransformJob":
182+
status = self.sagemaker_session.sagemaker_client.describe_transform_job(
183+
TransformJobName=step_name
184+
)["TransformJobStatus"]
185+
if status in ("Completed", "Failed", "Stopped"):
186+
running_jobs[step_name] = False
187+
if self._check_all_job_finished(running_jobs):
188+
break
189+
190+
def _check_all_job_finished(self, running_jobs):
191+
"""Check if all step jobs are finished.
192+
193+
Args:
194+
running_jobs (dict): a dictionary that keeps track of the status
195+
of each step job.
196+
197+
Returns (bool): True if all step jobs are finished. False if one or
198+
more step jobs are still running.
199+
"""
200+
for _, v in running_jobs.items():
201+
if v:
202+
return False
203+
return True
204+
205+
def _get_train_args(
206+
self, desc, inputs, name, volume_kms_key, encrypt_inter_container_traffic, vpc_config
207+
):
208+
"""Format training args to pass in sagemaker_session.train.
209+
210+
Args:
211+
desc (dict): the response from DescribeTrainingJob API.
212+
inputs (list): a list of input data channels.
213+
name (str): the name of the step job.
214+
volume_kms_key (str): The KMS key id to encrypt data on the storage volume attached to
215+
the ML compute instance(s).
216+
encrypt_inter_container_traffic (bool): To encrypt all communications between ML compute
217+
instances in distributed training.
218+
vpc_config (dict): Specifies a VPC that jobs and hosted models have access to.
219+
Control access to and from training and model containers by configuring the VPC
220+
221+
Returns (dcit): a dictionary that can be used as args of
222+
sagemaker_session.train method.
223+
"""
224+
train_args = {}
225+
train_args["input_config"] = inputs
226+
train_args["job_name"] = name
227+
train_args["input_mode"] = desc["AlgorithmSpecification"]["TrainingInputMode"]
228+
train_args["role"] = desc["RoleArn"]
229+
train_args["output_config"] = desc["OutputDataConfig"]
230+
train_args["resource_config"] = desc["ResourceConfig"]
231+
train_args["image"] = desc["AlgorithmSpecification"]["TrainingImage"]
232+
train_args["enable_network_isolation"] = desc["EnableNetworkIsolation"]
233+
train_args["encrypt_inter_container_traffic"] = encrypt_inter_container_traffic
234+
train_args["train_use_spot_instances"] = desc["EnableManagedSpotTraining"]
235+
train_args["hyperparameters"] = {}
236+
train_args["stop_condition"] = {}
237+
train_args["metric_definitions"] = None
238+
train_args["checkpoint_s3_uri"] = None
239+
train_args["checkpoint_local_path"] = None
240+
train_args["tags"] = []
241+
train_args["vpc_config"] = None
242+
243+
if volume_kms_key is not None:
244+
train_args["resource_config"]["VolumeKmsKeyId"] = volume_kms_key
245+
if "VpcConfig" in desc:
246+
train_args["vpc_config"] = desc["VpcConfig"]
247+
elif vpc_config is not None:
248+
train_args["vpc_config"] = vpc_config
249+
if "Hyperparameters" in desc:
250+
train_args["hyperparameters"] = desc["Hyperparameters"]
251+
if "CheckpointConfig" in desc:
252+
train_args["checkpoint_s3_uri"] = desc["CheckpointConfig"]["S3Uri"]
253+
train_args["checkpoint_local_path"] = desc["CheckpointConfig"]["LocalPath"]
254+
if "StoppingCondition" in desc:
255+
train_args["stop_condition"] = desc["StoppingCondition"]
256+
return train_args
257+
258+
def _get_transform_args(self, desc, inputs, name, volume_kms_key):
259+
"""Format training args to pass in sagemaker_session.train.
260+
261+
Args:
262+
desc (dict): the response from DescribeTrainingJob API.
263+
inputs (str): an S3 uri where new input dataset is stored.
264+
name (str): the name of the step job.
265+
volume_kms_key (str): The KMS key id to encrypt data on the storage volume attached to
266+
the ML compute instance(s).
267+
268+
Returns (dcit): a dictionary that can be used as args of
269+
sagemaker_session.transform method.
270+
"""
271+
transform_args = {}
272+
transform_args["job_name"] = name
273+
transform_args["model_name"] = desc["ModelName"]
274+
transform_args["output_config"] = desc["TransformOutput"]
275+
transform_args["resource_config"] = desc["TransformResources"]
276+
transform_args["data_processing"] = desc["DataProcessing"]
277+
transform_args["tags"] = []
278+
transform_args["strategy"] = None
279+
transform_args["max_concurrent_transforms"] = None
280+
transform_args["max_payload"] = None
281+
transform_args["env"] = None
282+
283+
input_config = desc["TransformInput"]
284+
input_config["DataSource"]["S3DataSource"]["S3Uri"] = inputs
285+
transform_args["input_config"] = input_config
286+
287+
if volume_kms_key is not None:
288+
transform_args["resource_config"]["VolumeKmsKeyId"] = volume_kms_key
289+
if "BatchStrategy" in desc:
290+
transform_args["strategy"] = desc["BatchStrategy"]
291+
if "MaxConcurrentTransforms" in desc:
292+
transform_args["max_concurrent_transforms"] = desc["MaxConcurrentTransforms"]
293+
if "MaxPayloadInMB" in desc:
294+
transform_args["max_payload"] = desc["MaxPayloadInMB"]
295+
if "Environment" in desc:
296+
transform_args["env"] = desc["Environment"]
297+
298+
return transform_args
299+
300+
def _process_steps(self, steps):
301+
"""Extract candidate's step jobs name and type.
302+
303+
Args:
304+
steps (list): a list of a candidate's step jobs.
305+
306+
Returns (list): a list of extracted information about step jobs'
307+
name and type.
308+
"""
309+
processed_steps = []
310+
for step in steps:
311+
step_name = step["CandidateStepName"]
312+
step_type = step["CandidateStepType"].split("::")[2]
313+
processed_steps.append({"name": step_name, "type": step_type})
314+
return processed_steps

src/sagemaker/inputs.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,14 @@ class s3_input(object):
2828
def __init__(
2929
self,
3030
s3_data,
31-
distribution="FullyReplicated",
31+
distribution=None,
3232
compression=None,
3333
content_type=None,
3434
record_wrapping=None,
3535
s3_data_type="S3Prefix",
3636
input_mode=None,
3737
attribute_names=None,
38+
target_attribute_name=None,
3839
shuffle_config=None,
3940
):
4041
"""Create a definition for input data used by an SageMaker training job.
@@ -69,21 +70,23 @@ def __init__(
6970
7071
attribute_names (list[str]): A list of one or more attribute names to use that are
7172
found in a specified AugmentedManifestFile.
73+
target_attribute_name (str): The name of the attribute will be predicted (classified)
74+
in a SageMaker AutoML job. It is required if the input is for SageMaker AutoML job.
7275
shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on
7376
this channel. See the SageMaker API documentation for more info:
7477
https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html
7578
"""
7679

7780
self.config = {
78-
"DataSource": {
79-
"S3DataSource": {
80-
"S3DataDistributionType": distribution,
81-
"S3DataType": s3_data_type,
82-
"S3Uri": s3_data,
83-
}
84-
}
81+
"DataSource": {"S3DataSource": {"S3DataType": s3_data_type, "S3Uri": s3_data}}
8582
}
8683

84+
if not (target_attribute_name or distribution):
85+
distribution = "FullyReplicated"
86+
87+
if distribution is not None:
88+
self.config["DataSource"]["S3DataSource"]["S3DataDistributionType"] = distribution
89+
8790
if compression is not None:
8891
self.config["CompressionType"] = compression
8992
if content_type is not None:
@@ -94,6 +97,8 @@ def __init__(
9497
self.config["InputMode"] = input_mode
9598
if attribute_names is not None:
9699
self.config["DataSource"]["S3DataSource"]["AttributeNames"] = attribute_names
100+
if target_attribute_name is not None:
101+
self.config["TargetAttributeName"] = target_attribute_name
97102
if shuffle_config is not None:
98103
self.config["ShuffleConfig"] = {"Seed": shuffle_config.seed}
99104

0 commit comments

Comments
 (0)