Skip to content

Commit 68dea63

Browse files
leopdnadiaya
authored andcommitted
Make the metric_names list optional in TrainingJobAnalytics object. (#274)
* Fix for issue #273: missing metric_names for training job analytics. * Fixup: adding note to changelog.
1 parent 1824610 commit 68dea63

File tree

6 files changed

+93
-6
lines changed

6 files changed

+93
-6
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ CHANGELOG
55
1.5.3dev
66
========
77

8+
* bug-fix: Can create TrainingJobAnalytics object without specifying metric_names.
89
* bug-fix: Session: include role path in ``get_execution_role()`` result
910
* bug-fix: Local Mode: fix RuntimeError handling
1011

src/sagemaker/analytics.py

+27-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from six import with_metaclass
2121

2222
from sagemaker.session import Session
23-
from sagemaker.utils import DeferredError
23+
from sagemaker.utils import DeferredError, extract_name_from_job_arn
2424

2525
try:
2626
import pandas as pd
@@ -201,12 +201,13 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):
201201

202202
CLOUDWATCH_NAMESPACE = '/aws/sagemaker/HyperParameterTuningJobs'
203203

204-
def __init__(self, training_job_name, metric_names, sagemaker_session=None):
204+
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None):
205205
"""Initialize a ``TrainingJobAnalytics`` instance.
206206
207207
Args:
208208
training_job_name (str): name of the TrainingJob to analyze.
209-
metric_names (list): string names of all the metrics to collect for this training job
209+
metric_names (list, optional): string names of all the metrics to collect for this training job.
210+
If not specified, then it will use all metric names configured for this job.
210211
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
211212
Amazon SageMaker APIs and any other AWS services needed. If not specified, one is specified
212213
using the default AWS configuration chain.
@@ -215,7 +216,10 @@ def __init__(self, training_job_name, metric_names, sagemaker_session=None):
215216
self._sage_client = sagemaker_session.sagemaker_client
216217
self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
217218
self._training_job_name = training_job_name
218-
self._metric_names = metric_names
219+
if metric_names:
220+
self._metric_names = metric_names
221+
else:
222+
self._metric_names = self._metric_names_for_training_job()
219223
self.clear_cache()
220224

221225
@property
@@ -297,3 +301,22 @@ def _add_single_metric(self, timestamp, metric_name, value):
297301
self._data['timestamp'].append(timestamp)
298302
self._data['metric_name'].append(metric_name)
299303
self._data['value'].append(value)
304+
305+
def _metric_names_for_training_job(self):
306+
"""Helper method to discover the metrics defined for a training job.
307+
"""
308+
# First look up the tuning job
309+
training_description = self._sage_client.describe_training_job(TrainingJobName=self._training_job_name)
310+
tuning_job_arn = training_description.get('TuningJobArn', None)
311+
if not tuning_job_arn:
312+
raise ValueError(
313+
"No metrics available. Training Job Analytics only available through Hyperparameter Tuning Jobs"
314+
)
315+
tuning_job_name = extract_name_from_job_arn(tuning_job_arn)
316+
tuning_job_description = self._sage_client.describe_hyper_parameter_tuning_job(
317+
HyperParameterTuningJobName=tuning_job_name
318+
)
319+
training_job_definition = tuning_job_description['TrainingJobDefinition']
320+
metric_definitions = training_job_definition['AlgorithmSpecification']['MetricDefinitions']
321+
metric_names = [md['Name'] for md in metric_definitions]
322+
return metric_names

src/sagemaker/estimator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def training_job_analytics(self):
324324
"""
325325
if self._current_job_name is None:
326326
raise ValueError('Estimator is not associated with a TrainingJob')
327-
return TrainingJobAnalytics(self._current_job_name)
327+
return TrainingJobAnalytics(self._current_job_name, sagemaker_session=self.sagemaker_session)
328328

329329

330330
class _TrainingJob(_Job):

src/sagemaker/utils.py

+10
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ def to_str(value):
120120
return str(value)
121121

122122

123+
def extract_name_from_job_arn(arn):
124+
"""Returns the name used in the API given a full ARN for a training job
125+
or hyperparameter tuning job.
126+
"""
127+
slash_pos = arn.find('/')
128+
if slash_pos == -1:
129+
raise ValueError("Cannot parse invalid ARN: %s" % arn)
130+
return arn[(slash_pos + 1):]
131+
132+
123133
class DeferredError(object):
124134
"""Stores an exception and raises it at a later time anytime this
125135
object is accessed in any way. Useful to allow soft-dependencies on imports,

tests/unit/test_estimator.py

+41
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,47 @@ def test_generic_to_deploy(sagemaker_session):
594594
assert predictor.sagemaker_session == sagemaker_session
595595

596596

597+
def test_generic_training_job_analytics(sagemaker_session):
598+
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value={
599+
'TuningJobArn': 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/mock-tuner',
600+
'TrainingStartTime': 1530562991.299,
601+
})
602+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
603+
name='describe_hyper_parameter_tuning_job',
604+
return_value={
605+
'TrainingJobDefinition': {
606+
"AlgorithmSpecification": {
607+
"TrainingImage": "some-image-url",
608+
"TrainingInputMode": "File",
609+
"MetricDefinitions": [
610+
{
611+
"Name": "train:loss",
612+
"Regex": "train_loss=([0-9]+\\.[0-9]+)"
613+
},
614+
{
615+
"Name": "validation:loss",
616+
"Regex": "valid_loss=([0-9]+\\.[0-9]+)"
617+
}
618+
]
619+
}
620+
}
621+
}
622+
)
623+
624+
e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='s3://bucket/prefix',
625+
sagemaker_session=sagemaker_session)
626+
627+
with pytest.raises(ValueError) as err: # noqa: F841
628+
# No training job yet
629+
a = e.training_job_analytics
630+
assert a is not None # This line is never reached
631+
632+
e.set_hyperparameters(**HYPERPARAMS)
633+
e.fit({'train': 's3://bucket/training-prefix'})
634+
a = e.training_job_analytics
635+
assert a is not None
636+
637+
597638
@patch('sagemaker.estimator.LocalSession')
598639
@patch('sagemaker.estimator.Session')
599640
def test_local_mode(session_class, local_session_class):

tests/unit/test_utils.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import pytest
1818
from mock import patch
1919

20-
from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError
20+
from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError, extract_name_from_job_arn
2121

2222
NAME = 'base_name'
2323

@@ -77,3 +77,15 @@ def test_to_str_with_native_string():
7777
def test_to_str_with_unicode_string():
7878
value = u'åñøthér strîng'
7979
assert to_str(value) == value
80+
81+
82+
def test_name_from_tuning_arn():
83+
arn = 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/resnet-sgd-tuningjob-11-07-34-11'
84+
name = extract_name_from_job_arn(arn)
85+
assert name == 'resnet-sgd-tuningjob-11-07-34-11'
86+
87+
88+
def test_name_from_training_arn():
89+
arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b'
90+
name = extract_name_from_job_arn(arn)
91+
assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b'

0 commit comments

Comments
 (0)