Skip to content

Commit 684d6a9

Browse files
authored
Initial checkin of SageMaker HPO Analytics library. (aws#22)
* Initial checkin of SageMaker HPO Analytics library.
1 parent 4d7dda5 commit 684d6a9

11 files changed

+566
-2
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ doc/_templates
2424
venv/
2525
*~
2626
.pytest_cache/
27+
*.swp

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ CHANGELOG
66
1.2.dev5
77
========
88

9+
* feature: Analytics functions for metrics in Training and HyperparameterTuning jobs
910
* bug-fix: Change module names to string type in __all__
1011
* feature: Save training output files in local mode
1112
* bug-fix: tensorflow-serving-api: SageMaker does not conflict with tensorflow-serving-api module version

src/sagemaker/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor
2222
from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor
2323
from sagemaker.amazon.randomcutforest import RandomCutForest, RandomCutForestModel, RandomCutForestPredictor
24+
from sagemaker.analytics import TrainingJobAnalytics, HyperparameterTuningJobAnalytics
2425

2526
from sagemaker.local.local_session import LocalSession
2627

@@ -39,4 +40,5 @@
3940
'FactorizationMachines', 'FactorizationMachinesModel', 'FactorizationMachinesPredictor',
4041
'RandomCutForest', 'RandomCutForestModel', 'RandomCutForestPredictor',
4142
'Model', 'NTM', 'NTMModel', 'NTMPredictor', 'RealTimePredictor', 'Session', 'LocalSession',
43+
'TrainingJobAnalytics', 'HyperparameterTuningJobAnalytics',
4244
'container_def', 's3_input', 'production_variant', 'get_execution_role']

src/sagemaker/analytics.py

+294
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import print_function, absolute_import
14+
15+
from abc import ABCMeta, abstractmethod
16+
from collections import defaultdict
17+
import datetime
18+
import logging
19+
20+
from six import with_metaclass
21+
22+
from sagemaker.session import Session
23+
from sagemaker.utils import DeferredError
24+
25+
try:
26+
import pandas as pd
27+
except ImportError as e:
28+
logging.warning("pandas failed to import. Analytics features will be impaired or broken.")
29+
# Any subsequent attempt to use pandas will raise the ImportError
30+
pd = DeferredError(e)
31+
32+
33+
class AnalyticsMetricsBase(with_metaclass(ABCMeta, object)):
    """Common parent of the tuning-job and training-job analytics classes.

    Implements the shared behaviour: building the dataframe on demand,
    caching it, and persisting it to a CSV file.
    """

    def export_csv(self, filename):
        """Write the analytics dataframe to a CSV file.

        Args:
            filename (str): The name of the file to save to.
        """
        frame = self.dataframe()
        frame.to_csv(filename)

    def dataframe(self, force_refresh=False):
        """Return the pandas dataframe summarising this object.

        The dataframe is produced by ``_fetch_dataframe`` the first time it is
        needed and is then served from the local cache.

        Args:
            force_refresh (bool): Set to True to discard the cache and fetch the
                latest data from the SageMaker API.
        """
        if force_refresh:
            self.clear_cache()
        cached = self._dataframe
        if cached is None:
            cached = self._fetch_dataframe()
            self._dataframe = cached
        return cached

    @abstractmethod
    def _fetch_dataframe(self):
        """Compute and return the dataframe. Every sub-class must implement this.
        """

    def clear_cache(self):
        """Drop the locally cached dataframe so the next access rebuilds it
        from the service.
        """
        self._dataframe = None
72+
73+
74+
class HyperparameterTuningJobAnalytics(AnalyticsMetricsBase):
    """Fetches results about a hyperparameter tuning job and makes them accessible for analytics.
    """

    # Number of summaries requested per ListTrainingJobsForHyperParameterTuningJob call.
    _PAGE_SIZE = 100
    # Upper bound on the number of pages fetched (presumably a guard against endless pagination).
    _MAX_PAGES = 100

    def __init__(self, hyperparameter_tuning_job_name, sagemaker_session=None):
        """Initialize a ``HyperparameterTuningJobAnalytics`` instance.

        Args:
            hyperparameter_tuning_job_name (str): name of the HyperparameterTuningJob to
                analyze.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, a default
                ``Session`` is created.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client
        self._tuning_job_name = hyperparameter_tuning_job_name
        # clear_cache() also initialises every cache attribute to None.
        self.clear_cache()

    @property
    def name(self):
        """Name of the HyperparameterTuningJob being analyzed
        """
        return self._tuning_job_name

    def __repr__(self):
        return "<sagemaker.HyperparameterTuningJobAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clears the object of all local caches of API methods.
        """
        super(HyperparameterTuningJobAnalytics, self).clear_cache()
        self._tuning_job_describe_result = None
        self._training_job_summaries = None

    def _fetch_dataframe(self):
        """Returns a pandas dataframe with one row per training job of the tuning job:
        its tuned hyperparameters, name, status, final objective value and timing.
        """
        def reshape(training_summary):
            # Helper to reshape a single training job summary into a dataframe record
            out = {}
            for k, v in training_summary['TunedHyperParameters'].items():
                # Something (bokeh?) gets confused with ints so convert to float
                try:
                    v = float(v)
                except (TypeError, ValueError):
                    pass
                out[k] = v
            out['TrainingJobName'] = training_summary['TrainingJobName']
            out['TrainingJobStatus'] = training_summary['TrainingJobStatus']
            out['FinalObjectiveValue'] = training_summary.get(
                'FinalHyperParameterTuningJobObjectiveMetric', {}).get('Value')

            # A job that has not finished yet has no 'TrainingEndTime'; use .get() so such
            # a job yields an empty cell instead of failing the whole dataframe with KeyError.
            # NOTE(review): the start time is read from 'CreationTime', not 'TrainingStartTime' -
            # confirm that this is intended.
            start_time = training_summary.get('CreationTime')
            end_time = training_summary.get('TrainingEndTime')
            out['TrainingStartTime'] = start_time
            out['TrainingEndTime'] = end_time
            if start_time and end_time:
                out['TrainingElapsedTimeSeconds'] = (end_time - start_time).total_seconds()
            return out

        # Run that helper over all the summaries.
        return pd.DataFrame([reshape(tjs) for tjs in self.training_job_summaries()])

    @property
    def tuning_ranges(self):
        """A dict describing the ranges of all tuned hyperparameters.
        Dict's key is the name of the hyper param. Dict's value is the range.
        """
        out = {}
        parameter_ranges = self.description()['HyperParameterTuningJobConfig']['ParameterRanges']
        for ranges in parameter_ranges.values():
            for param in ranges:
                out[param['Name']] = param
        return out

    def description(self, force_refresh=False):
        """Response to DescribeHyperParameterTuningJob

        Args:
            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
                Note that this clears every local cache of this object, not only the description.
        """
        if force_refresh:
            self.clear_cache()
        if not self._tuning_job_describe_result:
            self._tuning_job_describe_result = self._sage_client.describe_hyper_parameter_tuning_job(
                HyperParameterTuningJobName=self.name
            )
        return self._tuning_job_describe_result

    def training_job_summaries(self, force_refresh=False):
        """A list of everything (paginated) from ListTrainingJobsForTuningJob

        At most ``_MAX_PAGES`` pages of ``_PAGE_SIZE`` summaries are fetched; a warning is
        logged if the service still reports more results after that.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
                Note that this clears every local cache of this object.
        """
        if force_refresh:
            self.clear_cache()
        if self._training_job_summaries is not None:
            return self._training_job_summaries
        output = []
        next_args = {}
        for count in range(self._MAX_PAGES):
            logging.debug("Calling list_training_jobs_for_hyper_parameter_tuning_job %d", count)
            raw_result = self._sage_client.list_training_jobs_for_hyper_parameter_tuning_job(
                HyperParameterTuningJobName=self.name, MaxResults=self._PAGE_SIZE, **next_args
            )
            new_output = raw_result['TrainingJobSummaries']
            output.extend(new_output)
            logging.debug("Got %d more TrainingJobs. Total so far: %d", len(new_output), len(output))
            next_token = raw_result.get('NextToken')
            # Stop when the service has no further page, or returned an empty page.
            if not next_token or not new_output:
                break
            next_args['NextToken'] = next_token
        else:
            # The loop ran out of pages while the service still offered a NextToken.
            logging.warning("Stopped listing training jobs for %s after %d pages; results may be incomplete",
                            self.name, self._MAX_PAGES)
        self._training_job_summaries = output
        return output
191+
192+
193+
class TrainingJobAnalytics(AnalyticsMetricsBase):
    """Fetches training curve data from CloudWatch Metrics for a specific training job.
    """

    CLOUDWATCH_NAMESPACE = '/aws/sagemaker/HyperParameterTuningJobs'

    def __init__(self, training_job_name, metric_names=None, sagemaker_session=None):
        """Initialize a ``TrainingJobAnalytics`` instance.

        Args:
            training_job_name (str): name of the TrainingJob to analyze.
            metric_names (list): string names of all the metrics to collect for this training job.
                If not specified, the names are taken from the metric definitions reported by
                DescribeTrainingJob for this job.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, a default
                ``Session`` is created.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client
        self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
        self._training_job_name = training_job_name
        if metric_names is None:
            metric_names = self._metric_names_for_training_job()
        self._metric_names = metric_names
        # clear_cache() also initialises _dataframe, _data and _time_interval.
        self.clear_cache()

    @property
    def name(self):
        """Name of the TrainingJob being analyzed
        """
        return self._training_job_name

    def __repr__(self):
        return "<sagemaker.TrainingJobAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clears the object of all local caches of API methods, so
        that the next time any properties are accessed they will be refreshed from
        the service.

        Note that this calls DescribeTrainingJob to re-read the job's time interval.
        """
        super(TrainingJobAnalytics, self).clear_cache()
        self._data = defaultdict(list)
        self._time_interval = self._determine_timeinterval()

    def _metric_names_for_training_job(self):
        """Returns the names of the metrics defined for this training job, as listed under
        ``AlgorithmSpecification.MetricDefinitions`` of the DescribeTrainingJob response.
        Returns an empty list when the job defines no metrics.
        """
        description = self._sage_client.describe_training_job(TrainingJobName=self.name)
        definitions = description.get('AlgorithmSpecification', {}).get('MetricDefinitions', [])
        return [definition['Name'] for definition in definitions]

    def _determine_timeinterval(self):
        """Returns a dict with two datetime objects, start_time and end_time
        covering the interval of the training job
        """
        description = self._sage_client.describe_training_job(TrainingJobName=self.name)
        start_time = description[u'TrainingStartTime']  # datetime object
        # A job that is still running has no end time yet; use the current UTC time instead.
        end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow())
        return {
            'start_time': start_time,
            'end_time': end_time,
        }

    def _fetch_dataframe(self):
        """Fetches every requested metric from CloudWatch and returns them as one dataframe
        with the columns ``timestamp``, ``metric_name`` and ``value``.
        """
        # Start from an empty store so that a retry after a partially failed fetch
        # cannot duplicate rows that were already collected.
        self._data = defaultdict(list)
        for metric_name in self._metric_names:
            self._fetch_metric(metric_name)
        return pd.DataFrame(self._data)

    def _fetch_metric(self, metric_name):
        """Fetches all the values of a named metric, and adds them to _data
        """
        request = {
            'Namespace': self.CLOUDWATCH_NAMESPACE,
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'TrainingJobName',
                    'Value': self.name
                }
            ],
            'StartTime': self._time_interval['start_time'],
            'EndTime': self._time_interval['end_time'],
            'Period': 60,
            'Statistics': ['Average'],
        }
        raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints']
        if not raw_cwm_data:
            logging.warning("Warning: No metrics called %s found", metric_name)
            return

        # Process data: sort by time, then express each timestamp as seconds elapsed
        # since the earliest datapoint.
        ordered = sorted(raw_cwm_data, key=lambda pt: pt['Timestamp'])
        base_time = ordered[0]['Timestamp']

        # Store everything in _data to make a dataframe from
        for pt in ordered:
            elapsed_seconds = (pt['Timestamp'] - base_time).total_seconds()
            self._add_single_metric(elapsed_seconds, metric_name, pt['Average'])

    def _add_single_metric(self, timestamp, metric_name, value):
        """Stores a single metric in the _data dict which can be
        converted to a dataframe.
        """
        # note that this method is built this way to make it possible to
        # support live-refreshing charts in Bokeh at some point in the future.
        self._data['timestamp'].append(timestamp)
        self._data['metric_name'].append(metric_name)
        self._data['value'].append(value)

src/sagemaker/estimator.py

+9
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from sagemaker.session import Session
3131
from sagemaker.session import s3_input
3232
from sagemaker.utils import base_name_from_image, name_from_base, get_config_value
33+
from sagemaker.analytics import TrainingJobAnalytics
3334

3435

3536
class EstimatorBase(with_metaclass(ABCMeta, object)):
@@ -317,6 +318,14 @@ def delete_endpoint(self):
317318
raise ValueError('Endpoint was not created yet')
318319
self.sagemaker_session.delete_endpoint(self.latest_training_job.name)
319320

321+
@property
def training_job_analytics(self):
    """Returns a TrainingJobAnalytics object for the current training job.

    The analytics object is given this estimator's ``sagemaker_session`` so that it
    reuses that session instead of constructing a new default one.

    Raises:
        ValueError: If the estimator has no current training job name.
    """
    if self._current_job_name is None:
        raise ValueError('Estimator is not associated with a TrainingJob')
    # NOTE(review): no metric names are supplied here - confirm that the
    # TrainingJobAnalytics constructor can be called without them.
    return TrainingJobAnalytics(self._current_job_name, sagemaker_session=self.sagemaker_session)
328+
320329

321330
class _TrainingJob(_Job):
322331
def __init__(self, sagemaker_session, training_job_name):

src/sagemaker/tuner.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import inspect
1616
import json
1717

18+
from sagemaker.analytics import HyperparameterTuningJobAnalytics
1819
from sagemaker.estimator import Framework
1920
from sagemaker.job import _Job
2021
from sagemaker.utils import base_name_from_image, name_from_base
@@ -201,6 +202,19 @@ def hyperparameter_ranges(self):
201202
hyperparameter_ranges[range_type + 'ParameterRanges'] = parameter_ranges
202203
return hyperparameter_ranges
203204

205+
@property
def sagemaker_session(self):
    """Convenience accessor for the session object.

    The tuner shares the sagemaker_session object with its estimator, so this
    simply hands back the estimator's session.
    """
    estimator = self.estimator
    return estimator.sagemaker_session
211+
212+
def analytics(self):
    """An instance of HyperparameterTuningJobAnalytics for the latest tuning job of this tuner.
    The analytics object gives you access to tuning results summarized into a pandas dataframe.

    Raises:
        ValueError: If the tuner has no latest tuning job.
    """
    latest_job = self.latest_tuning_job
    if latest_job is None:
        raise ValueError('Tuner is not associated with a HyperparameterTuningJob')
    # The analytics class expects the job *name* (a string), not the tuning job object.
    return HyperparameterTuningJobAnalytics(latest_job.name, self.sagemaker_session)
217+
204218
def _validate_parameter_ranges(self):
205219
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
206220

@@ -262,7 +276,7 @@ def start_new(cls, tuner, inputs):
262276
resource_config=(config['resource_config']),
263277
stop_condition=(config['stop_condition']))
264278

265-
return cls(tuner.estimator.sagemaker_session, tuning_job_name)
279+
return cls(tuner.sagemaker_session, tuning_job_name)
266280

267281
def stop(self):
268282
self.sagemaker_session.stop_tuning_job(HyperParameterTuningJobName=self.name)

0 commit comments

Comments
 (0)