Skip to content

Make the metric_names list optional in TrainingJobAnalytics object. #274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 5, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ CHANGELOG
1.5.3dev
========

* bug-fix: Can create TrainingJobAnalytics object without specifying metric_names.
* bug-fix: Session: include role path in ``get_execution_role()`` result

1.5.2
Expand Down
31 changes: 27 additions & 4 deletions src/sagemaker/analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from six import with_metaclass

from sagemaker.session import Session
from sagemaker.utils import DeferredError
from sagemaker.utils import DeferredError, extract_name_from_job_arn

try:
import pandas as pd
Expand Down Expand Up @@ -201,12 +201,13 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):

CLOUDWATCH_NAMESPACE = '/aws/sagemaker/HyperParameterTuningJobs'

def __init__(self, training_job_name, metric_names, sagemaker_session=None):
def __init__(self, training_job_name, metric_names=None, sagemaker_session=None):
"""Initialize a ``TrainingJobAnalytics`` instance.

Args:
training_job_name (str): name of the TrainingJob to analyze.
metric_names (list): string names of all the metrics to collect for this training job
metric_names (list, optional): string names of all the metrics to collect for this training job.
If not specified, then it will use all metric names configured for this job.
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
Amazon SageMaker APIs and any other AWS services needed. If not specified, one is specified
using the default AWS configuration chain.
Expand All @@ -215,7 +216,10 @@ def __init__(self, training_job_name, metric_names, sagemaker_session=None):
self._sage_client = sagemaker_session.sagemaker_client
self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch')
self._training_job_name = training_job_name
self._metric_names = metric_names
if metric_names:
self._metric_names = metric_names
else:
self._metric_names = self._metric_names_for_training_job()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be done in one line:

self._metric_names = metric_names or self._metric_names_for_training_job()

self.clear_cache()

@property
Expand Down Expand Up @@ -297,3 +301,22 @@ def _add_single_metric(self, timestamp, metric_name, value):
self._data['timestamp'].append(timestamp)
self._data['metric_name'].append(metric_name)
self._data['value'].append(value)

def _metric_names_for_training_job(self):
"""Helper method to discover the metrics defined for a training job.
"""
# First look up the tuning job
training_description = self._sage_client.describe_training_job(TrainingJobName=self._training_job_name)
tuning_job_arn = training_description.get('TuningJobArn', None)
if not tuning_job_arn:
raise ValueError(
"No metrics available. Training Job Analytics only available through Hyperparameter Tuning Jobs"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

single quotes

)
tuning_job_name = extract_name_from_job_arn(tuning_job_arn)
tuning_job_description = self._sage_client.describe_hyper_parameter_tuning_job(
HyperParameterTuningJobName=tuning_job_name
)
training_job_definition = tuning_job_description['TrainingJobDefinition']
metric_definitions = training_job_definition['AlgorithmSpecification']['MetricDefinitions']
metric_names = [md['Name'] for md in metric_definitions]
return metric_names
2 changes: 1 addition & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def training_job_analytics(self):
"""
if self._current_job_name is None:
raise ValueError('Estimator is not associated with a TrainingJob')
return TrainingJobAnalytics(self._current_job_name)
return TrainingJobAnalytics(self._current_job_name, sagemaker_session=self.sagemaker_session)


class _TrainingJob(_Job):
Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,16 @@ def to_str(value):
return str(value)


def extract_name_from_job_arn(arn):
"""Returns the name used in the API given a full ARN for a training job
or hyperparameter tuning job.
"""
slash_pos = arn.find('/')
if slash_pos == -1:
raise ValueError("Cannot parse invalid ARN: %s" % arn)
return arn[(slash_pos + 1):]


class DeferredError(object):
"""Stores an exception and raises it at a later time anytime this
object is accessed in any way. Useful to allow soft-dependencies on imports,
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,47 @@ def test_generic_to_deploy(sagemaker_session):
assert predictor.sagemaker_session == sagemaker_session


def test_generic_training_job_analytics(sagemaker_session):
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value={
'TuningJobArn': 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/mock-tuner',
'TrainingStartTime': 1530562991.299,
})
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
name='describe_hyper_parameter_tuning_job',
return_value={
'TrainingJobDefinition': {
"AlgorithmSpecification": {
"TrainingImage": "some-image-url",
"TrainingInputMode": "File",
"MetricDefinitions": [
{
"Name": "train:loss",
"Regex": "train_loss=([0-9]+\\.[0-9]+)"
},
{
"Name": "validation:loss",
"Regex": "valid_loss=([0-9]+\\.[0-9]+)"
}
]
}
}
}
)

e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='s3://bucket/prefix',
sagemaker_session=sagemaker_session)

with pytest.raises(ValueError) as err: # noqa: F841
# No training job yet
a = e.training_job_analytics
assert a is not None # This line is never reached

e.set_hyperparameters(**HYPERPARAMS)
e.fit({'train': 's3://bucket/training-prefix'})
a = e.training_job_analytics
assert a is not None


@patch('sagemaker.estimator.LocalSession')
@patch('sagemaker.estimator.Session')
def test_local_mode(session_class, local_session_class):
Expand Down
14 changes: 13 additions & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
from mock import patch

from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError
from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError, extract_name_from_job_arn

NAME = 'base_name'

Expand Down Expand Up @@ -77,3 +77,15 @@ def test_to_str_with_native_string():
def test_to_str_with_unicode_string():
value = u'åñøthér strîng'
assert to_str(value) == value


def test_name_from_tuning_arn():
arn = 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/resnet-sgd-tuningjob-11-07-34-11'
name = extract_name_from_job_arn(arn)
assert name == 'resnet-sgd-tuningjob-11-07-34-11'


def test_name_from_training_arn():
arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b'
name = extract_name_from_job_arn(arn)
assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b'