forked from aws/sagemaker-python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path candidate_estimator.py
336 lines (294 loc) · 14 KB
/
candidate_estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AutoML Job's Candidate."""
from __future__ import absolute_import

import copy
import time

from six import string_types

from sagemaker import Session
from sagemaker.job import _Job
from sagemaker.utils import name_from_base
class CandidateEstimator(object):
    """A class for SageMaker AutoML Job Candidate.

    Wraps a candidate description returned by AutoML and can describe or rerun
    the candidate's training and transform step jobs.
    """

    # Seconds to sleep between status polls while ``fit(wait=True)`` waits for the
    # rerun jobs, so the Describe* APIs are not called in a tight loop.
    _JOB_STATUS_POLL_SECONDS = 5

    # Job statuses after which a step job is no longer considered running.
    _TERMINAL_STATUSES = ("Completed", "Failed", "Stopped")

    def __init__(self, candidate, sagemaker_session=None):
        """Constructor of CandidateEstimator.

        Args:
            candidate (dict): a dictionary of candidate returned by AutoML.list_candidates()
                or AutoML.best_candidate(). The keys "CandidateName",
                "InferenceContainers" and "CandidateSteps" are read.
            sagemaker_session (sagemaker.session.Session): A SageMaker Session
                object, used for SageMaker interactions (default: None). If not
                specified, one is created using the default AWS configuration
                chain.
        """
        self.name = candidate["CandidateName"]
        self.containers = candidate["InferenceContainers"]
        self.steps = self._process_steps(candidate["CandidateSteps"])
        self.sagemaker_session = sagemaker_session or Session()

    def get_steps(self):
        """Get the step jobs of a candidate so that users can construct estimators/transformers.

        Only training and transform steps are described; steps of any other type
        are skipped.

        Returns:
            list[CandidateStep]: one entry per training/transform step, holding the
                step job's name, inputs, type and full Describe* response.
        """
        client = self.sagemaker_session.sagemaker_client
        candidate_steps = []
        for step in self.steps:
            step_type = step["type"]
            step_name = step["name"]
            if step_type == "TrainingJob":
                description = client.describe_training_job(TrainingJobName=step_name)
                inputs = description["InputDataConfig"]
            elif step_type == "TransformJob":
                description = client.describe_transform_job(TransformJobName=step_name)
                inputs = description["TransformInput"]
            else:
                continue
            candidate_steps.append(CandidateStep(step_name, inputs, step_type, description))
        return candidate_steps

    def fit(
        self,
        inputs,
        candidate_name=None,
        volume_kms_key=None,
        encrypt_inter_container_traffic=False,
        vpc_config=None,
        wait=True,
        logs=True,
    ):
        """Rerun a candidate's step jobs with new input datasets or security config.

        Each training and transform step is started again under a new job name
        (derived from "sagemaker-automl-training-rerun" or
        "sagemaker-automl-transform-rerun"), and ``self.steps`` is updated in place
        to refer to the new jobs. Steps of any other type are not rerun.

        Args:
            inputs (str): Local path or S3 Uri where the training data is stored. If a
                local path is provided, the dataset will be uploaded to an S3 location.
                Non-string inputs are rejected with ``ValueError`` by training and
                transform steps.
            candidate_name (str): name of the candidate to be rerun, if None, candidate's original
                name will be used.
            volume_kms_key (str): The KMS key id to encrypt data on the storage volume attached to
                the ML compute instance(s).
            encrypt_inter_container_traffic (bool): To encrypt all communications between ML compute
                instances in distributed training. Default: False.
            vpc_config (dict): Specifies a VPC that jobs and hosted models have access to.
                Control access to and from training and model containers by configuring the VPC.
                When given, it overrides the VPC config of the original training job; when
                None, the original job's VPC config (if any) is reused.
            wait (bool): Whether the call should wait until all jobs complete (default: True).
            logs (bool): Whether to show the logs produced by the job.
                Only meaningful when wait is True (default: True). This method only
                validates it against ``wait``; it does not stream job logs itself.

        Raises:
            ValueError: if ``logs`` is True while ``wait`` is False, or if ``inputs``
                cannot be used as the input of a training/transform step.
        """
        if logs and not wait:
            raise ValueError(
                "Logs can only be shown if wait is set to True.\n"
                "Please either set wait to True or set logs to False."
            )
        self.name = candidate_name or self.name
        # Maps each newly started job name to True while that job is still running.
        running_jobs = {}

        # Upload a local dataset once, before the loop, so every step reruns on the same S3 copy.
        if isinstance(inputs, string_types) and not inputs.startswith("s3://"):
            inputs = self.sagemaker_session.upload_data(inputs, key_prefix="auto-ml-input-data")

        for step in self.steps:
            if step["type"] == "TrainingJob":
                job_name = self._rerun_training_step(
                    step, inputs, volume_kms_key, encrypt_inter_container_traffic, vpc_config
                )
            elif step["type"] == "TransformJob":
                job_name = self._rerun_transform_step(step, inputs, volume_kms_key)
            else:
                continue
            running_jobs[job_name] = True

        if wait:
            self._wait_for_jobs(running_jobs)

    def _rerun_training_step(
        self, step, inputs, volume_kms_key, encrypt_inter_container_traffic, vpc_config
    ):
        """Start a copy of a training step job on ``inputs`` and return the new job name.

        ``step["name"]`` is overwritten with the new job name.

        Raises:
            ValueError: if ``inputs`` is not a string.
        """
        if not isinstance(inputs, string_types):
            msg = "Cannot format input {}. Expecting a string."
            raise ValueError(msg.format(inputs))
        input_dict = {"train": _Job._format_string_uri_input(inputs)}
        channels = [
            _Job._convert_input_to_channel(channel_name, channel_input)
            for channel_name, channel_input in input_dict.items()
        ]
        # The original job's description is the template for the rerun.
        desc = self.sagemaker_session.sagemaker_client.describe_training_job(
            TrainingJobName=step["name"]
        )
        job_name = name_from_base("sagemaker-automl-training-rerun")
        step["name"] = job_name
        train_args = self._get_train_args(
            desc,
            channels,
            job_name,
            volume_kms_key,
            encrypt_inter_container_traffic,
            vpc_config,
        )
        self.sagemaker_session.train(**train_args)
        return job_name

    def _rerun_transform_step(self, step, inputs, volume_kms_key):
        """Start a copy of a transform step job on ``inputs`` and return the new job name.

        ``step["name"]`` is overwritten with the new job name.

        Raises:
            ValueError: if ``inputs`` is not a string starting with "s3://".
        """
        if not isinstance(inputs, string_types) or not inputs.startswith("s3://"):
            msg = "Cannot format input {}. Expecting a string starts with file:// or s3://"
            raise ValueError(msg.format(inputs))
        desc = self.sagemaker_session.sagemaker_client.describe_transform_job(
            TransformJobName=step["name"]
        )
        job_name = name_from_base("sagemaker-automl-transform-rerun")
        step["name"] = job_name
        transform_args = self._get_transform_args(desc, inputs, job_name, volume_kms_key)
        self.sagemaker_session.transform(**transform_args)
        return job_name

    def _get_step_status(self, step_type, step_name):
        """Return the current status of a training or transform step job (None otherwise)."""
        client = self.sagemaker_session.sagemaker_client
        if step_type == "TrainingJob":
            return client.describe_training_job(TrainingJobName=step_name)["TrainingJobStatus"]
        if step_type == "TransformJob":
            return client.describe_transform_job(TransformJobName=step_name)[
                "TransformJobStatus"
            ]
        return None

    def _wait_for_jobs(self, running_jobs):
        """Block until every job marked running in ``running_jobs`` reaches a terminal status.

        Args:
            running_jobs (dict): maps job name -> True while the job is running;
                entries are set to False in place as jobs finish.
        """
        while True:
            for step in self.steps:
                step_name = step["name"]
                # Skip jobs that already finished or were not started by fit().
                if not running_jobs.get(step_name):
                    continue
                status = self._get_step_status(step["type"], step_name)
                if status in self._TERMINAL_STATUSES:
                    running_jobs[step_name] = False
            if self._check_all_job_finished(running_jobs):
                break
            time.sleep(self._JOB_STATUS_POLL_SECONDS)

    def _check_all_job_finished(self, running_jobs):
        """Check if all step jobs are finished.

        Args:
            running_jobs (dict): maps each step job name to True while it is running.

        Returns:
            bool: True if no step job is still running (also when ``running_jobs``
                is empty), False if one or more step jobs are still running.
        """
        return not any(running_jobs.values())

    def _get_train_args(
        self, desc, inputs, name, volume_kms_key, encrypt_inter_container_traffic, vpc_config
    ):
        """Format training args to pass in sagemaker_session.train.

        ``desc`` is not modified.

        Args:
            desc (dict): the response from DescribeTrainingJob API.
            inputs (list): a list of input data channels.
            name (str): the name of the step job.
            volume_kms_key (str): The KMS key id to encrypt data on the storage volume attached to
                the ML compute instance(s). Overrides the original job's key when not None.
            encrypt_inter_container_traffic (bool): To encrypt all communications between ML compute
                instances in distributed training.
            vpc_config (dict): Specifies a VPC that jobs and hosted models have access to.
                Overrides the original job's ``VpcConfig`` when not None; otherwise the
                original one (if any) is reused.

        Returns:
            dict: a dictionary that can be used as args of sagemaker_session.train method.
        """
        algorithm = desc["AlgorithmSpecification"]
        train_args = {
            "input_config": inputs,
            "job_name": name,
            "input_mode": algorithm["TrainingInputMode"],
            "role": desc["RoleArn"],
            "output_config": desc["OutputDataConfig"],
            # Copy so the KMS override below does not mutate ``desc``.
            "resource_config": dict(desc["ResourceConfig"]),
            "image_uri": algorithm["TrainingImage"],
            # Read with .get so a response without these flags does not raise KeyError.
            "enable_network_isolation": desc.get("EnableNetworkIsolation", False),
            "encrypt_inter_container_traffic": encrypt_inter_container_traffic,
            "train_use_spot_instances": desc.get("EnableManagedSpotTraining", False),
            "hyperparameters": desc.get("Hyperparameters", {}),
            "stop_condition": desc.get("StoppingCondition", {}),
            "metric_definitions": None,
            "checkpoint_s3_uri": None,
            "checkpoint_local_path": None,
            "tags": [],
            # An explicitly requested VPC wins over the original job's VPC.
            "vpc_config": vpc_config if vpc_config is not None else desc.get("VpcConfig"),
        }
        if volume_kms_key is not None:
            train_args["resource_config"]["VolumeKmsKeyId"] = volume_kms_key
        if "CheckpointConfig" in desc:
            checkpoint_config = desc["CheckpointConfig"]
            train_args["checkpoint_s3_uri"] = checkpoint_config["S3Uri"]
            # LocalPath may be absent; fall back to None instead of raising KeyError.
            train_args["checkpoint_local_path"] = checkpoint_config.get("LocalPath")
        return train_args

    def _get_transform_args(self, desc, inputs, name, volume_kms_key):
        """Format transform args to pass in sagemaker_session.transform.

        ``desc`` is not modified.

        Args:
            desc (dict): the response from DescribeTransformJob API.
            inputs (str): an S3 uri where new input dataset is stored.
            name (str): the name of the step job.
            volume_kms_key (str): The KMS key id to encrypt data on the storage volume attached to
                the ML compute instance(s). Overrides the original job's key when not None.

        Returns:
            dict: a dictionary that can be used as args of sagemaker_session.transform method.
        """
        # Deep copy: the S3 URI lives in a nested dict of the original response.
        input_config = copy.deepcopy(desc["TransformInput"])
        input_config["DataSource"]["S3DataSource"]["S3Uri"] = inputs
        transform_args = {
            "job_name": name,
            "model_name": desc["ModelName"],
            "strategy": desc.get("BatchStrategy"),
            "max_concurrent_transforms": desc.get("MaxConcurrentTransforms"),
            "max_payload": desc.get("MaxPayloadInMB"),
            "env": desc.get("Environment"),
            "input_config": input_config,
            "output_config": desc["TransformOutput"],
            # Copy so the KMS override below does not mutate ``desc``.
            "resource_config": dict(desc["TransformResources"]),
            "tags": [],
            # Read with .get so a response without DataProcessing does not raise KeyError.
            "data_processing": desc.get("DataProcessing"),
            "experiment_config": None,
        }
        if volume_kms_key is not None:
            transform_args["resource_config"]["VolumeKmsKeyId"] = volume_kms_key
        return transform_args

    def _process_steps(self, steps):
        """Extract candidate's step jobs name and type.

        Args:
            steps (list): a list of a candidate's step jobs, each a dict with
                "CandidateStepName" and "CandidateStepType" keys.

        Returns:
            list: dicts with "name" and "type" keys; "type" is the third
                "::"-separated component of "CandidateStepType"
                (e.g. "TrainingJob" for "AWS::SageMaker::TrainingJob").
        """
        return [
            {
                "name": step["CandidateStepName"],
                "type": step["CandidateStepType"].split("::")[2],
            }
            for step in steps
        ]
class CandidateStep(object):
    """One step of an AutoML candidate: its job name, inputs, job type and description.

    All four values are exposed through read-only properties.

    Args:
        name (str): name of the step job.
        inputs: input configuration of the step job (``InputDataConfig`` for a
            training job, ``TransformInput`` for a transform job when built by
            ``CandidateEstimator.get_steps``).
        step_type (str): job type of the step, e.g. "TrainingJob" or "TransformJob".
        description (dict): the Describe* API response for the step job.
    """

    def __init__(self, name, inputs, step_type, description):
        self._name, self._inputs = name, inputs
        self._type, self._description = step_type, description

    name = property(
        lambda self: self._name,
        doc="Name of the candidate step -> (str)",
    )

    inputs = property(
        lambda self: self._inputs,
        doc="Inputs of the candidate step -> (dict)",
    )

    type = property(
        lambda self: self._type,
        doc="Type of the candidate step, Training or Transform -> (str)",
    )

    description = property(
        lambda self: self._description,
        doc="Description of candidate step job -> (dict)",
    )