From 2e68a11c88169b33a0a0226c3e4bba39dc86e97d Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Mon, 2 Jul 2018 15:18:19 -0700 Subject: [PATCH 1/2] Fix for issue #273: missing metric_names for training job analytics. --- src/sagemaker/analytics.py | 31 +++++++++++++++++++++++---- src/sagemaker/estimator.py | 2 +- src/sagemaker/utils.py | 10 +++++++++ tests/unit/test_estimator.py | 41 ++++++++++++++++++++++++++++++++++++ tests/unit/test_utils.py | 14 +++++++++++- 5 files changed, 92 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/analytics.py b/src/sagemaker/analytics.py index b81aeeef00..21a2029542 100644 --- a/src/sagemaker/analytics.py +++ b/src/sagemaker/analytics.py @@ -20,7 +20,7 @@ from six import with_metaclass from sagemaker.session import Session -from sagemaker.utils import DeferredError +from sagemaker.utils import DeferredError, extract_name_from_job_arn try: import pandas as pd @@ -201,12 +201,13 @@ class TrainingJobAnalytics(AnalyticsMetricsBase): CLOUDWATCH_NAMESPACE = '/aws/sagemaker/HyperParameterTuningJobs' - def __init__(self, training_job_name, metric_names, sagemaker_session=None): + def __init__(self, training_job_name, metric_names=None, sagemaker_session=None): """Initialize a ``TrainingJobAnalytics`` instance. Args: training_job_name (str): name of the TrainingJob to analyze. - metric_names (list): string names of all the metrics to collect for this training job + metric_names (list, optional): string names of all the metrics to collect for this training job. + If not specified, then it will use all metric names configured for this job. sagemaker_session (sagemaker.session.Session): Session object which manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, one is specified using the default AWS configuration chain. @@ -215,7 +216,10 @@ def __init__(self, training_job_name, metric_names, sagemaker_session=None): self._sage_client = sagemaker_session.sagemaker_client self._cloudwatch = sagemaker_session.boto_session.client('cloudwatch') self._training_job_name = training_job_name - self._metric_names = metric_names + if metric_names: + self._metric_names = metric_names + else: + self._metric_names = self._metric_names_for_training_job() self.clear_cache() @property @@ -297,3 +301,22 @@ def _add_single_metric(self, timestamp, metric_name, value): self._data['timestamp'].append(timestamp) self._data['metric_name'].append(metric_name) self._data['value'].append(value) + + def _metric_names_for_training_job(self): + """Helper method to discover the metrics defined for a training job. + """ + # First look up the tuning job + training_description = self._sage_client.describe_training_job(TrainingJobName=self._training_job_name) + tuning_job_arn = training_description.get('TuningJobArn', None) + if not tuning_job_arn: + raise ValueError( + "No metrics available. Training Job Analytics only available through Hyperparameter Tuning Jobs" + ) + tuning_job_name = extract_name_from_job_arn(tuning_job_arn) + tuning_job_description = self._sage_client.describe_hyper_parameter_tuning_job( + HyperParameterTuningJobName=tuning_job_name + ) + training_job_definition = tuning_job_description['TrainingJobDefinition'] + metric_definitions = training_job_definition['AlgorithmSpecification']['MetricDefinitions'] + metric_names = [md['Name'] for md in metric_definitions] + return metric_names diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index c7b03099e7..fba27ed661 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -324,7 +324,7 @@ def training_job_analytics(self): """ if self._current_job_name is None: raise ValueError('Estimator is not associated with a TrainingJob') - return TrainingJobAnalytics(self._current_job_name) + return TrainingJobAnalytics(self._current_job_name, sagemaker_session=self.sagemaker_session) class _TrainingJob(_Job): diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index a6534425c3..fe6dc42264 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -120,6 +120,16 @@ def to_str(value): return str(value) +def extract_name_from_job_arn(arn): + """Returns the name used in the API given a full ARN for a training job + or hyperparameter tuning job. + """ + slash_pos = arn.find('/') + if slash_pos == -1: + raise ValueError("Cannot parse invalid ARN: %s" % arn) + return arn[(slash_pos + 1):] + + class DeferredError(object): """Stores an exception and raises it at a later time anytime this object is accessed in any way. Useful to allow soft-dependencies on imports, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 2de791b772..b39df5f29d 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -594,6 +594,47 @@ def test_generic_to_deploy(sagemaker_session): assert predictor.sagemaker_session == sagemaker_session +def test_generic_training_job_analytics(sagemaker_session): + sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value={ + 'TuningJobArn': 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/mock-tuner', + 'TrainingStartTime': 1530562991.299, + }) + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock( + name='describe_hyper_parameter_tuning_job', + return_value={ + 'TrainingJobDefinition': { + "AlgorithmSpecification": { + "TrainingImage": "some-image-url", + "TrainingInputMode": "File", + "MetricDefinitions": [ + { + "Name": "train:loss", + "Regex": "train_loss=([0-9]+\\.[0-9]+)" + }, + { + "Name": "validation:loss", + "Regex": "valid_loss=([0-9]+\\.[0-9]+)" + } + ] + } + } + } + ) + + e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='s3://bucket/prefix', + sagemaker_session=sagemaker_session) + + with pytest.raises(ValueError) as err: # noqa: F841 + # No training job yet + a = e.training_job_analytics + assert a is not None # This line is never reached + + e.set_hyperparameters(**HYPERPARAMS) + e.fit({'train': 's3://bucket/training-prefix'}) + a = e.training_job_analytics + assert a is not None + + @patch('sagemaker.estimator.LocalSession') @patch('sagemaker.estimator.Session') def test_local_mode(session_class, local_session_class): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index de40983895..1170d76475 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -17,7 +17,7 @@ import pytest from mock import patch -from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError +from sagemaker.utils import get_config_value, name_from_base, to_str, DeferredError, extract_name_from_job_arn NAME = 'base_name' @@ -77,3 +77,15 @@ def test_to_str_with_native_string(): def test_to_str_with_unicode_string(): value = u'åñøthér strîng' assert to_str(value) == value + + +def test_name_from_tuning_arn(): + arn = 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/resnet-sgd-tuningjob-11-07-34-11' + name = extract_name_from_job_arn(arn) + assert name == 'resnet-sgd-tuningjob-11-07-34-11' + + +def test_name_from_training_arn(): + arn = 'arn:aws:sagemaker:us-west-2:968277160000:training-job/resnet-sgd-tuningjob-11-22-38-46-002-2927640b' + name = extract_name_from_job_arn(arn) + assert name == 'resnet-sgd-tuningjob-11-22-38-46-002-2927640b' From ee98a1579b32ba6e6279a10101a3948862a357de Mon Sep 17 00:00:00 2001 From: Leo Dirac Date: Mon, 2 Jul 2018 15:43:58 -0700 Subject: [PATCH 2/2] Fixup: adding note to changelog. --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6be0fd089d..8f8f2984ed 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,7 @@ CHANGELOG 1.5.3dev ======== +* bug-fix: Can create TrainingJobAnalytics object without specifying metric_names. * bug-fix: Session: include role path in ``get_execution_role()`` result 1.5.2