20
20
from six import with_metaclass
21
21
22
22
from sagemaker .session import Session
23
- from sagemaker .utils import DeferredError
23
+ from sagemaker .utils import DeferredError , extract_name_from_job_arn
24
24
25
25
try :
26
26
import pandas as pd
@@ -201,12 +201,13 @@ class TrainingJobAnalytics(AnalyticsMetricsBase):
201
201
202
202
CLOUDWATCH_NAMESPACE = '/aws/sagemaker/HyperParameterTuningJobs'
203
203
204
- def __init__ (self , training_job_name , metric_names , sagemaker_session = None ):
204
+ def __init__ (self , training_job_name , metric_names = None , sagemaker_session = None ):
205
205
"""Initialize a ``TrainingJobAnalytics`` instance.
206
206
207
207
Args:
208
208
training_job_name (str): name of the TrainingJob to analyze.
209
- metric_names (list): string names of all the metrics to collect for this training job
209
+ metric_names (list, optional): string names of all the metrics to collect for this training job.
210
+ If not specified, then it will use all metric names configured for this job.
210
211
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
211
212
Amazon SageMaker APIs and any other AWS services needed. If not specified, one is specified
212
213
using the default AWS configuration chain.
@@ -215,7 +216,10 @@ def __init__(self, training_job_name, metric_names, sagemaker_session=None):
215
216
self ._sage_client = sagemaker_session .sagemaker_client
216
217
self ._cloudwatch = sagemaker_session .boto_session .client ('cloudwatch' )
217
218
self ._training_job_name = training_job_name
218
- self ._metric_names = metric_names
219
+ if metric_names :
220
+ self ._metric_names = metric_names
221
+ else :
222
+ self ._metric_names = self ._metric_names_for_training_job ()
219
223
self .clear_cache ()
220
224
221
225
@property
@@ -297,3 +301,22 @@ def _add_single_metric(self, timestamp, metric_name, value):
297
301
self ._data ['timestamp' ].append (timestamp )
298
302
self ._data ['metric_name' ].append (metric_name )
299
303
self ._data ['value' ].append (value )
304
+
305
def _metric_names_for_training_job(self):
    """Discover the metric names defined for this training job.

    The training job is described first to obtain the ARN of its hyperparameter
    tuning job; the metric names are then read from the algorithm specification
    of that tuning job's training job definition.

    Returns:
        list: the ``Name`` of every metric definition found on the tuning job.
            Empty when the tuning job carries no ``MetricDefinitions`` entry.

    Raises:
        ValueError: if the training job description has no ``TuningJobArn``.
    """
    # First look up the tuning job
    training_description = self._sage_client.describe_training_job(
        TrainingJobName=self._training_job_name
    )
    tuning_job_arn = training_description.get('TuningJobArn')
    if not tuning_job_arn:
        raise ValueError(
            "No metrics available. Training Job Analytics only available through Hyperparameter Tuning Jobs"
        )
    tuning_job_name = extract_name_from_job_arn(tuning_job_arn)
    tuning_job_description = self._sage_client.describe_hyper_parameter_tuning_job(
        HyperParameterTuningJobName=tuning_job_name
    )
    training_job_definition = tuning_job_description['TrainingJobDefinition']
    algorithm_specification = training_job_definition['AlgorithmSpecification']
    # ``MetricDefinitions`` is an optional field of the algorithm specification,
    # so a missing (or null) entry means "no metrics configured" rather than
    # an error; avoid surfacing it as a bare KeyError.
    metric_definitions = algorithm_specification.get('MetricDefinitions') or []
    return [md['Name'] for md in metric_definitions]
0 commit comments