Skip to content

Commit fc2fcc0

Browse files
committed
Make start time, end time and period configurable in analytics.TrainingJobAnalytics
Creating a TrainingJobAnalytics object fails if the training job has too many data points for the specified metrics, because CloudWatch GetMetricStatistics limits how many data points a single call can return. Make start time, end time and period configurable so the caller can stay under this limit: https://docs.aws.amazon.com/AmazonCloudWatch/latest/APIReference/API_GetMetricStatistics.html. Original issue: aws#701
1 parent 475e051 commit fc2fcc0

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

src/sagemaker/analytics.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,8 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):
201201

202202
CLOUDWATCH_NAMESPACE = '/aws/sagemaker/TrainingJobs'
203203

204-
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None):
204+
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None,
205+
start_time=None, end_time=None, period=None):
205206
"""Initialize a ``TrainingJobAnalytics`` instance.
206207
207208
Args:
@@ -216,6 +217,10 @@ def __init__(self, training_job_name, metric_names=None, sagemaker_session=None)
216217
self._sage_client = sagemaker_session.sagemaker_client
217218
self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
218219
self._training_job_name = training_job_name
220+
self._start_time = start_time
221+
self._end_time = end_time
222+
self._period = period if period else 60
223+
219224
if metric_names:
220225
self._metric_names = metric_names
221226
else:
@@ -245,13 +250,15 @@ def _determine_timeinterval(self):
245250
covering the interval of the training job
246251
"""
247252
description = self._sage_client.describe_training_job(TrainingJobName=self.name)
248-
start_time = description[u'TrainingStartTime'] # datetime object
253+
start_time = self._start_time if self._start_time else description[u'TrainingStartTime'] # datetime object
249254
# Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
250255
# This results in logs being searched in the time range in which the correct log line was not present.
251256
# Example - Log time - 2018-10-22 08:25:55
252257
# Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition)
253258
# CW will consider end time as 2018-10-22 08:25 and will not be able to search the correct log.
254-
end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
259+
end_time = self._end_time if self._end_time else description.get(
260+
u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
261+
255262
return {
256263
'start_time': start_time,
257264
'end_time': end_time,
@@ -276,7 +283,7 @@ def _fetch_metric(self, metric_name):
276283
],
277284
'StartTime': self._time_interval['start_time'],
278285
'EndTime': self._time_interval['end_time'],
279-
'Period': 60,
286+
'Period': self._period,
280287
'Statistics': ['Average'],
281288
}
282289
raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints']

tests/unit/test_analytics.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,8 @@ def test_trainer_name():
197197
'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
198198
}
199199
session = create_sagemaker_session(describe_training_result)
200-
trainer = TrainingJobAnalytics("my-training-job", ["metric"], sagemaker_session=session)
200+
trainer = TrainingJobAnalytics(training_job_name="my-training-job", metric_names=["metric"],
201+
sagemaker_session=session)
201202
assert trainer.name == "my-training-job"
202203
assert str(trainer).find("my-training-job") != -1
203204

@@ -231,7 +232,8 @@ def _metric_stats_results():
231232
def test_trainer_dataframe():
232233
session = create_sagemaker_session(describe_training_result=_describe_training_result(),
233234
metric_stats_results=_metric_stats_results())
234-
trainer = TrainingJobAnalytics("my-training-job", ["train:acc"], sagemaker_session=session)
235+
trainer = TrainingJobAnalytics(training_job_name="my-training-job", metric_names=["train:acc"],
236+
sagemaker_session=session)
235237

236238
df = trainer.dataframe()
237239
assert df is not None
@@ -245,3 +247,20 @@ def test_trainer_dataframe():
245247
trainer.export_csv(tmp_name)
246248
assert os.path.isfile(tmp_name)
247249
os.unlink(tmp_name)
250+
251+
252+
def test_start_time_end_time_and_period_specified():
    """Caller-supplied ``start_time``, ``end_time`` and ``period`` are used as given."""
    # The job's own timestamps differ from the requested window below, so the
    # assertions can only pass if the explicit arguments take precedence.
    job_description = {
        'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
        'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
    }
    requested_start = datetime.datetime(2018, 5, 16, 1, 3, 4)
    requested_end = datetime.datetime(2018, 5, 16, 5, 1, 1)
    requested_period = 300

    analytics = TrainingJobAnalytics(
        training_job_name="my-training-job",
        metric_names=["metric"],
        sagemaker_session=create_sagemaker_session(job_description),
        start_time=requested_start,
        end_time=requested_end,
        period=requested_period
    )

    # An explicit end time is compared as-is, with no adjustment applied.
    assert analytics._time_interval['start_time'] == requested_start
    assert analytics._time_interval['end_time'] == requested_end
    assert analytics._period == requested_period

0 commit comments

Comments
 (0)