Skip to content

change: make start time, end time and period configurable in sagemaker.analytics.TrainingJobAnalytics #730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 2, 2019
17 changes: 13 additions & 4 deletions src/sagemaker/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
# Any subsequent attempt to use pandas will raise the ImportError
pd = DeferredError(e)

METRICS_PERIOD_DEFAULT = 60 # seconds


class AnalyticsMetricsBase(with_metaclass(ABCMeta, object)):
"""Base class for tuning job or training job analytics classes.
Expand Down Expand Up @@ -201,7 +203,8 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):

CLOUDWATCH_NAMESPACE = '/aws/sagemaker/TrainingJobs'

def __init__(self, training_job_name, metric_names=None, sagemaker_session=None):
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None,
start_time=None, end_time=None, period=None):
"""Initialize a ``TrainingJobAnalytics`` instance.
Args:
Expand All @@ -216,6 +219,10 @@ def __init__(self, training_job_name, metric_names=None, sagemaker_session=None)
self._sage_client = sagemaker_session.sagemaker_client
self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
self._training_job_name = training_job_name
self._start_time = start_time
self._end_time = end_time
self._period = period or METRICS_PERIOD_DEFAULT

if metric_names:
self._metric_names = metric_names
else:
Expand Down Expand Up @@ -245,13 +252,15 @@ def _determine_timeinterval(self):
covering the interval of the training job
"""
description = self._sage_client.describe_training_job(TrainingJobName=self.name)
start_time = description[u'TrainingStartTime'] # datetime object
start_time = self._start_time or description[u'TrainingStartTime'] # datetime object
# Incrementing end time by 1 min since CloudWatch drops seconds before finding the logs.
# This results in logs being searched in the time range in which the correct log line was not present.
# Example - Log time - 2018-10-22 08:25:55
# Here calculated end time would also be 2018-10-22 08:25:55 (without 1 min addition)
# CW will consider end time as 2018-10-22 08:25 and will not be able to search the correct log.
end_time = description.get(u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)
end_time = self._end_time or description.get(
u'TrainingEndTime', datetime.datetime.utcnow()) + datetime.timedelta(minutes=1)

return {
'start_time': start_time,
'end_time': end_time,
Expand All @@ -276,7 +285,7 @@ def _fetch_metric(self, metric_name):
],
'StartTime': self._time_interval['start_time'],
'EndTime': self._time_interval['end_time'],
'Period': 60,
'Period': self._period,
'Statistics': ['Average'],
}
raw_cwm_data = self._cloudwatch.get_metric_statistics(**request)['Datapoints']
Expand Down
4 changes: 4 additions & 0 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_mnist(sagemaker_session, instance_type):
sagemaker_session=sagemaker_session,
py_version='py3',
framework_version=TensorFlow.LATEST_VERSION,
metric_definitions=[{'Name': 'train:global_steps', 'Regex': r'global_step\/sec:\s(.*)'}],
base_job_name='test-tf-sm-mnist')
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(RESOURCE_PATH, 'data'),
Expand All @@ -56,6 +57,9 @@ def test_mnist(sagemaker_session, instance_type):
estimator.fit(inputs)
_assert_s3_files_exist(estimator.model_dir,
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta'])
df = estimator.training_job_analytics.dataframe()
print(df)
assert df.size > 0


def test_server_side_encryption(sagemaker_session):
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,20 @@ def test_trainer_dataframe():
trainer.export_csv(tmp_name)
assert os.path.isfile(tmp_name)
os.unlink(tmp_name)


def test_start_time_end_time_and_period_specified():
describe_training_result = {
'TrainingStartTime': datetime.datetime(2018, 5, 16, 1, 2, 3),
'TrainingEndTime': datetime.datetime(2018, 5, 16, 5, 6, 7),
}
session = create_sagemaker_session(describe_training_result)
start_time = datetime.datetime(2018, 5, 16, 1, 3, 4)
end_time = datetime.datetime(2018, 5, 16, 5, 1, 1)
period = 300
trainer = TrainingJobAnalytics('my-training-job', ['metric'],
sagemaker_session=session, start_time=start_time, end_time=end_time, period=period)

assert trainer._time_interval['start_time'] == start_time
assert trainer._time_interval['end_time'] == end_time
assert trainer._period == period