# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import print_function, absolute_import

from abc import ABCMeta, abstractmethod
from collections import defaultdict
import datetime
import logging

from six import with_metaclass

from sagemaker.session import Session
from sagemaker.utils import DeferredError

try:
    import pandas as pd
except ImportError as e:
    logging.warning("pandas failed to import. Analytics features will be impaired or broken.")
    # Any subsequent attempt to use pandas will raise the ImportError
    pd = DeferredError(e)
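    # Sketch of the assumed DeferredError behaviour: module loading continues, and a later
    # call such as pd.DataFrame(...) re-raises the original ImportError at the point of use.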


class AnalyticsMetricsBase(with_metaclass(ABCMeta, object)):
    """Base class for tuning job or training job analytics classes.
    Understands common functionality like persistence and caching.
    """

    def export_csv(self, filename):
        """Persists the analytics dataframe to a file.

        Args:
            filename (str): The name of the file to save to.
        """
        self.dataframe().to_csv(filename)
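
    # Usage sketch (the tuning job name is hypothetical):
    #   HyperparameterTuningJobAnalytics('my-tuning-job').export_csv('/tmp/my-tuning-job.csv')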

    def dataframe(self, force_refresh=False):
        """A pandas dataframe of detailed results about this object, created by calling
        SageMaker List and Describe APIs and converting the responses into a convenient
        tabular summary.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
        """
        if force_refresh:
            self.clear_cache()
        if self._dataframe is None:
            self._dataframe = self._fetch_dataframe()
        return self._dataframe

    @abstractmethod
    def _fetch_dataframe(self):
        """Sub-classes must calculate the dataframe and return it.
        """
        pass

    def clear_cache(self):
        """Clears the object of all local caches of API methods, so
        that the next time any properties are accessed they will be refreshed from
        the service.
        """
        self._dataframe = None


class HyperparameterTuningJobAnalytics(AnalyticsMetricsBase):
    """Fetches results about this tuning job and makes them accessible for analytics.
    """

    def __init__(self, hyperparameter_tuning_job_name, sagemaker_session=None):
        """Initialize a ``HyperparameterTuningJobAnalytics`` instance.

        Args:
            hyperparameter_tuning_job_name (str): name of the HyperparameterTuningJob to
                analyze.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created
                using the default AWS configuration chain.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client
        self._tuning_job_name = hyperparameter_tuning_job_name
        self.clear_cache()

    @property
    def name(self):
        """Name of the HyperparameterTuningJob being analyzed.
        """
        return self._tuning_job_name

    def __repr__(self):
        return "<sagemaker.HyperparameterTuningJobAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clears the object of all local caches of API methods.
        """
        super(HyperparameterTuningJobAnalytics, self).clear_cache()
        self._tuning_job_describe_result = None
        self._training_job_summaries = None

    def _fetch_dataframe(self):
        """Returns a pandas dataframe with one row per training job started by this tuning job,
        including its tuned hyperparameters, final objective value, status, and timing metadata.
        """
        def reshape(training_summary):
            # Helper method to reshape a single training job summary into a dataframe record
            out = {}
            for k, v in training_summary['TunedHyperParameters'].items():
                # Convert numeric hyperparameter values to floats; some consumers (e.g. bokeh)
                # get confused by ints
                try:
                    v = float(v)
                except (TypeError, ValueError):
                    pass
                out[k] = v
            out['TrainingJobName'] = training_summary['TrainingJobName']
            out['TrainingJobStatus'] = training_summary['TrainingJobStatus']
            out['FinalObjectiveValue'] = training_summary.get(
                'FinalHyperParameterTuningJobObjectiveMetric', {}).get('Value')

            start_time = training_summary['CreationTime']
            # TrainingEndTime is absent while the training job is still running
            end_time = training_summary.get('TrainingEndTime')
            out['TrainingStartTime'] = start_time
            out['TrainingEndTime'] = end_time
            if start_time and end_time:
                out['TrainingElapsedTimeSeconds'] = (end_time - start_time).total_seconds()
            return out

        # Run that helper over all the summaries.
        df = pd.DataFrame([reshape(tjs) for tjs in self.training_job_summaries()])
        return df
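
    # Example of working with the resulting dataframe (a sketch; the tuning job name is
    # hypothetical, and the sort assumes the objective metric is being maximized):
    #
    #   df = HyperparameterTuningJobAnalytics('my-tuning-job').dataframe()
    #   df.sort_values('FinalObjectiveValue', ascending=False).head()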

    @property
    def tuning_ranges(self):
        """A dictionary describing the ranges of all tuned hyperparameters.
        The keys are the hyperparameter names, and the values are the range definitions.
        """
        out = {}
        for _, ranges in self.description()['HyperParameterTuningJobConfig']['ParameterRanges'].items():
            for param in ranges:
                out[param['Name']] = param
        return out
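
    # Illustrative shape of the returned dict, based on the ParameterRanges structure in the
    # DescribeHyperParameterTuningJob response (hyperparameter names and values are made up):
    #
    #   {'learning_rate': {'Name': 'learning_rate', 'MinValue': '0.01', 'MaxValue': '0.2'},
    #    'optimizer': {'Name': 'optimizer', 'Values': ['sgd', 'adam']}}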

    def description(self, force_refresh=False):
        """The response to a DescribeHyperParameterTuningJob call for this tuning job.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
        """
        if force_refresh:
            self.clear_cache()
        if not self._tuning_job_describe_result:
            self._tuning_job_describe_result = self._sage_client.describe_hyper_parameter_tuning_job(
                HyperParameterTuningJobName=self.name
            )
        return self._tuning_job_describe_result

    def training_job_summaries(self, force_refresh=False):
        """A (paginated) list of everything from ListTrainingJobsForHyperParameterTuningJob.

        Args:
            force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
        """
        if force_refresh:
            self.clear_cache()
        if self._training_job_summaries is not None:
            return self._training_job_summaries
        output = []
        next_args = {}
        for count in range(100):
            logging.debug("Calling list_training_jobs_for_hyper_parameter_tuning_job %d", count)
            raw_result = self._sage_client.list_training_jobs_for_hyper_parameter_tuning_job(
                HyperParameterTuningJobName=self.name, MaxResults=100, **next_args
            )
            new_output = raw_result['TrainingJobSummaries']
            output.extend(new_output)
            logging.debug("Got %d more TrainingJobs. Total so far: %d", len(new_output), len(output))
            if ('NextToken' in raw_result) and (len(new_output) > 0):
                next_args['NextToken'] = raw_result['NextToken']
            else:
                break
        self._training_job_summaries = output
        return output

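
# Sketch of direct use of the raw accessors (the tuning job name is hypothetical):
#
#   tuner = HyperparameterTuningJobAnalytics('my-tuning-job')
#   tuner.tuning_ranges                      # tuned hyperparameter ranges
#   len(tuner.training_job_summaries())      # number of training jobs launched so far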


class TrainingJobAnalytics(AnalyticsMetricsBase):
    """Fetches training curve data from CloudWatch Metrics for a specific training job.
    """

    CLOUDWATCH_NAMESPACE = '/aws/sagemaker/HyperParameterTuningJobs'

    def __init__(self, training_job_name, metric_names, sagemaker_session=None):
        """Initialize a ``TrainingJobAnalytics`` instance.

        Args:
            training_job_name (str): name of the TrainingJob to analyze.
            metric_names (list): string names of all the metrics to collect for this training job.
            sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
                Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created
                using the default AWS configuration chain.
        """
        sagemaker_session = sagemaker_session or Session()
        self._sage_client = sagemaker_session.sagemaker_client
        self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
        self._training_job_name = training_job_name
        self._metric_names = metric_names
        self.clear_cache()

    @property
    def name(self):
        """Name of the TrainingJob being analyzed.
        """
        return self._training_job_name

    def __repr__(self):
        return "<sagemaker.TrainingJobAnalytics for %s>" % self.name

    def clear_cache(self):
        """Clears the object of all local caches of API methods, so
        that the next time any properties are accessed they will be refreshed from
        the service.
        """
        super(TrainingJobAnalytics, self).clear_cache()
        self._data = defaultdict(list)
        self._time_interval = self._determine_timeinterval()

    def _determine_timeinterval(self):
        """Returns a dict with two datetime objects, start_time and end_time,
        covering the interval of the training job.
        """
        description = self._sage_client.describe_training_job(TrainingJobName=self.name)
        start_time = description[u'TrainingStartTime']  # datetime object
        # If the job is still running, use the current time as the end of the interval
        end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow())
        return {
            'start_time': start_time,
            'end_time': end_time,
        }

    def _fetch_dataframe(self):
        for metric_name in self._metric_names:
            self._fetch_metric(metric_name)
        return pd.DataFrame(self._data)

    def _fetch_metric(self, metric_name):
        """Fetches all the values of a named metric, and adds them to _data.
        """
        request = {
            'Namespace': self.CLOUDWATCH_NAMESPACE,
            'MetricName': metric_name,
            'Dimensions': [
                {
                    'Name': 'TrainingJobName',
                    'Value': self.name
                }
            ],
            'StartTime': self._time_interval['start_time'],
            'EndTime': self._time_interval['end_time'],
            'Period': 60,
            'Statistics': ['Average'],
        }
        raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints']
        if len(raw_cwm_data) == 0:
            logging.warning("No metrics called %s found", metric_name)
            return

        # Process data: normalize to starting time, and sort.
        base_time = min(raw_cwm_data, key=lambda pt: pt['Timestamp'])['Timestamp']
        all_xy = []
        for pt in raw_cwm_data:
            y = pt['Average']
            x = (pt['Timestamp'] - base_time).total_seconds()
            all_xy.append([x, y])
        all_xy = sorted(all_xy, key=lambda x: x[0])

        # Store everything in _data to make a dataframe from
        for elapsed_seconds, value in all_xy:
            self._add_single_metric(elapsed_seconds, metric_name, value)

    def _add_single_metric(self, timestamp, metric_name, value):
        """Stores a single metric in the _data dict which can be
        converted to a dataframe.
        """
        # Note that this method is built this way to make it possible to
        # support live-refreshing charts in Bokeh at some point in the future.
        self._data['timestamp'].append(timestamp)
        self._data['metric_name'].append(metric_name)
        self._data['value'].append(value)
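

# Sketch of how the resulting long-format dataframe might be used (training job and metric
# names below are hypothetical):
#
#   analytics = TrainingJobAnalytics('my-training-job', metric_names=['train:loss'])
#   df = analytics.dataframe()
#   # One row per (timestamp, metric_name, value); pivot to one column per metric:
#   df.pivot(index='timestamp', columns='metric_name', values='value')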