Skip to content

Commit b15c05a

Browse files
authored
change: make start time, end time and period configurable in sagemaker.analytics.TrainingJobAnalytics (#730)
* Make start time, end time and period configurable in analytics.TrainingJobAnalytics Creating an TrainingJobAnalytics object fails if the training job has too many data points in the specified metrics. Make start time, end time and period configurable so the caller can get around this limit - https://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_GetMetricStatistics.html Original issue: #701 * Add analytics integ test to the TensorFlow script mode minist test * Minor changes due to PR comments and to make flake8 happy * More minor changes * One more minor change
1 parent b6e1993 commit b15c05a

File tree

3 files changed

+34
-4
lines changed

3 files changed

+34
-4
lines changed

src/sagemaker/analytics.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
# Any subsequent attempt to use pandas will raise the ImportError
3030
pd = DeferredError(e)
3131

32+
METRICS_PERIOD_DEFAULT = 60 # seconds
33+
3234

3335
class AnalyticsMetricsBase(with_metaclass(ABCMeta, object)):
3436
"""Base class for tuning job or training job analytics classes.
@@ -201,7 +203,8 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):
201203

202204
CLOUDWATCH_NAMESPACE = '/aws/sagemaker/TrainingJobs'
203205

204-
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None):
206+
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None,
207+
start_time=None, end_time=None, period=None):
205208
"""Initialize a ``TrainingJobAnalytics`` instance.
206209
207210
Args:
@@ -216,6 +219,10 @@ def __init__(self, training_job_name, metric_names=None, sagemaker_session=None)
216219
self._sage_client = sagemaker_session.sagemaker_client
217220
self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
218221
self._training_job_name = training_job_name
222+
self._start_time = start_time
223+
self._end_time = end_time
224+
self._period = period or METRICS_PERIOD_DEFAULT
225+
219226
if metric_names:
220227
self._metric_names = metric_names
221228
else:
@@ -245,13 +252,15 @@ def _determine_timeinterval(self):
245252
covering the interval of the training job
246253
"""
247254
description = self._sage_client.describe_training_job(TrainingJobName=self.name)
248-
start_time = description[u'TrainingStartTime'] # datetime object
255+
start_time = self._start_time or description[u'TrainingStartTime'] # datetime object
249256
# Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
250257
# This results in logs being searched in the time range in which the correct log line was not present.
251258
# Example - Log time - 2018-10-22 08:25:55
252259
# Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition)
253260
# CW will consider end time as 2018-10-22 08:25 and will not be able to search the correct log.
254-
end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
261+
end_time = self._end_time or description.get(
262+
u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
263+
255264
return {
256265
'start_time': start_time,
257266
'end_time': end_time,
@@ -276,7 +285,7 @@ def _fetch_metric(self, metric_name):
276285
],
277286
'StartTime': self._time_interval['start_time'],
278287
'EndTime': self._time_interval['end_time'],
279-
'Period': 60,
288+
'Period': self._period,
280289
'Statistics': ['Average'],
281290
}
282291
raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints']

tests/integ/test_tf_script_mode.py

+4
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def test_mnist(sagemaker_session, instance_type):
4747
sagemaker_session=sagemaker_session,
4848
py_version='py3',
4949
framework_version=TensorFlow.LATEST_VERSION,
50+
metric_definitions=[{'Name': 'train:global_steps', 'Regex': r'global_step\/sec:\s(.*)'}],
5051
base_job_name='test-tf-sm-mnist')
5152
inputs = estimator.sagemaker_session.upload_data(
5253
path=os.path.join(RESOURCE_PATH, 'data'),
@@ -56,6 +57,9 @@ def test_mnist(sagemaker_session, instance_type):
5657
estimator.fit(inputs)
5758
_assert_s3_files_exist(estimator.model_dir,
5859
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta'])
60+
df = estimator.training_job_analytics.dataframe()
61+
print(df)
62+
assert df.size > 0
5963

6064

6165
def test_server_side_encryption(sagemaker_session):

tests/unit/test_analytics.py

+17
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,20 @@ def test_trainer_dataframe():
245245
trainer.export_csv(tmp_name)
246246
assert os.path.isfile(tmp_name)
247247
os.unlink(tmp_name)
248+
249+
250+
def test_start_time_end_time_and_period_specified():
251+
describe_training_result = {
252+
'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
253+
'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
254+
}
255+
session = create_sagemaker_session(describe_training_result)
256+
start_time = datetime.datetime(2018, 5, 16, 1, 3, 4)
257+
end_time = datetime.datetime(2018, 5, 16, 5, 1, 1)
258+
period = 300
259+
trainer = TrainingJobAnalytics('my-training-job', ['metric'],
260+
sagemaker_session=session, start_time=start_time, end_time=end_time, period=period)
261+
262+
assert trainer._time_interval['start_time'] == start_time
263+
assert trainer._time_interval['end_time'] == end_time
264+
assert trainer._period == period

0 commit comments

Comments
 (0)