Skip to content

Commit 6412991

Browse files
SifeiLilaurenyu
authored andcommitted
Support MetricDefinitions for general training jobs (aws#484)
1 parent 97bbc39 commit 6412991

File tree

11 files changed

+67
-16
lines changed

11 files changed

+67
-16
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ CHANGELOG
1313
* feature: HyperparameterTuner: add support for Automatic Model Tuning's Warm Start Jobs
1414
* feature: HyperparameterTuner: Make input channels optional
1515
* feature: Add support for Chainer 5.0
16+
* feature: Estimator: add support for MetricDefinitions
1617

1718
1.14.2
1819
======

README.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,25 @@ Here is an end to end example of how to use a SageMaker Estimator:
170170
# Tears down the SageMaker endpoint
171171
mxnet_estimator.delete_endpoint()
172172
173+
Training Metrics
174+
~~~~~~~~~~~~~~~~
175+
The SageMaker Python SDK allows you to specify a name and a regular expression for metrics you want to track for training.
176+
A regular expression (regex) matches what is in the training algorithm logs, like a search function.
177+
Here is an example of how to define metrics:
178+
179+
.. code:: python
180+
181+
# Configure an BYO Estimator with metric definitions (no training happens yet)
182+
byo_estimator = Estimator(image_name=image_name,
183+
role='SageMakerRole', train_instance_count=1,
184+
train_instance_type='ml.c4.xlarge',
185+
sagemaker_session=sagemaker_session,
186+
metric_definitions=[{'Name': 'test:msd', 'Regex': '#quality_metric: host=\S+, test msd <loss>=(\S+)'},
187+
{'Name': 'test:ssd', 'Regex': '#quality_metric: host=\S+, test ssd <loss>=(\S+)'}])
188+
189+
All Amazon SageMaker algorithms come with built-in support for metrics.
190+
You can go to `the AWS documentation <https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html>`__ for more details about built-in metrics of each Amazon SageMaker algorithm.
191+
173192
Local Mode
174193
~~~~~~~~~~
175194

src/sagemaker/estimator.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
5050
def __init__(self, role, train_instance_count, train_instance_type,
5151
train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60, input_mode='File',
5252
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None,
53-
subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model'):
53+
subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model',
54+
metric_definitions=None):
5455
"""Initialize an ``EstimatorBase`` instance.
5556
5657
Args:
@@ -97,6 +98,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
9798
9899
More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
99100
model_channel_name (str): Name of the channel where 'model_uri' will be downloaded (default: 'model').
101+
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
102+
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
103+
the regular expression used to extract the metric from the logs. This should be defined only
104+
for jobs that don't use an Amazon algorithm.
100105
"""
101106
self.role = role
102107
self.train_instance_count = train_instance_count
@@ -106,6 +111,7 @@ def __init__(self, role, train_instance_count, train_instance_type,
106111
self.train_max_run = train_max_run
107112
self.input_mode = input_mode
108113
self.tags = tags
114+
self.metric_definitions = metric_definitions
109115
self.model_uri = model_uri
110116
self.model_channel_name = model_channel_name
111117

@@ -324,6 +330,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
324330
init_params['hyperparameters'] = job_details['HyperParameters']
325331
init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage']
326332

333+
if 'MetricDefinitons' in job_details['AlgorithmSpecification']:
334+
init_params['metric_definitions'] = job_details['AlgorithmSpecification']['MetricsDefinition']
335+
327336
subnets, security_group_ids = vpc_utils.from_dict(job_details.get(vpc_utils.VPC_CONFIG_KEY))
328337
if subnets:
329338
init_params['subnets'] = subnets
@@ -441,7 +450,7 @@ def start_new(cls, estimator, inputs):
441450
job_name=estimator._current_job_name, output_config=config['output_config'],
442451
resource_config=config['resource_config'], vpc_config=config['vpc_config'],
443452
hyperparameters=hyperparameters, stop_condition=config['stop_condition'],
444-
tags=estimator.tags)
453+
tags=estimator.tags, metric_definitions=estimator.metric_definitions)
445454

446455
return cls(estimator.sagemaker_session, estimator._current_job_name)
447456

@@ -466,7 +475,7 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
466475
train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60,
467476
input_mode='File', output_path=None, output_kms_key=None, base_job_name=None,
468477
sagemaker_session=None, hyperparameters=None, tags=None, subnets=None, security_group_ids=None,
469-
model_uri=None, model_channel_name='model'):
478+
model_uri=None, model_channel_name='model', metric_definitions=None):
470479
"""Initialize an ``Estimator`` instance.
471480
472481
Args:
@@ -517,14 +526,18 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
517526
518527
More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
519528
model_channel_name (str): Name of the channel where 'model_uri' will be downloaded (default: 'model').
529+
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
530+
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
531+
the regular expression used to extract the metric from the logs. This should be defined only
532+
for jobs that don't use an Amazon algorithm.
520533
"""
521534
self.image_name = image_name
522535
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
523536
super(Estimator, self).__init__(role, train_instance_count, train_instance_type,
524537
train_volume_size, train_volume_kms_key, train_max_run, input_mode,
525538
output_path, output_kms_key, base_job_name, sagemaker_session,
526539
tags, subnets, security_group_ids, model_uri=model_uri,
527-
model_channel_name=model_channel_name)
540+
model_channel_name=model_channel_name, metric_definitions=metric_definitions)
528541

529542
def train_image(self):
530543
"""

src/sagemaker/session.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def default_bucket(self):
203203
return self._default_bucket
204204

205205
def train(self, image, input_mode, input_config, role, job_name, output_config,
206-
resource_config, vpc_config, hyperparameters, stop_condition, tags):
206+
resource_config, vpc_config, hyperparameters, stop_condition, tags, metric_definitions):
207207
"""Create an Amazon SageMaker training job.
208208
209209
Args:
@@ -243,6 +243,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
243243
service like ``MaxRuntimeInSeconds``.
244244
tags (list[dict]): List of tags for labeling a training job. For more, see
245245
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
246+
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
247+
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
248+
the regular expression used to extract the metric from the logs.
246249
247250
Returns:
248251
str: ARN of the training job, if it is created.
@@ -263,6 +266,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
263266
if input_config is not None:
264267
train_request['InputDataConfig'] = input_config
265268

269+
if metric_definitions is not None:
270+
train_request['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions
271+
266272
if hyperparameters and len(hyperparameters) > 0:
267273
train_request['HyperParameters'] = hyperparameters
268274

@@ -306,7 +312,7 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
306312
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
307313
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
308314
the regular expression used to extract the metric from the logs. This should be defined only for
309-
hyperparameter tuning jobs that don't use an Amazon algorithm.
315+
jobs that don't use an Amazon algorithm.
310316
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
311317
that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
312318
You must grant sufficient permissions to this role.

tests/unit/test_chainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ def _create_train_job(version):
121121
'MaxRuntimeInSeconds': 24 * 60 * 60
122122
},
123123
'tags': None,
124-
'vpc_config': None
124+
'vpc_config': None,
125+
'metric_definitions': None
125126
}
126127

127128

tests/unit/test_estimator.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def test_framework_all_init_args(sagemaker_session):
143143
sagemaker_session=sagemaker_session, train_volume_size=123, train_volume_kms_key='volumekms',
144144
train_max_run=456, input_mode='inputmode', output_path='outputpath', output_kms_key='outputkms',
145145
base_job_name='basejobname', tags=[{'foo': 'bar'}], subnets=['123', '456'],
146-
security_group_ids=['789', '012'])
146+
security_group_ids=['789', '012'],
147+
metric_definitions=[{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}])
147148
_TrainingJob.start_new(f, 's3://mydata')
148149
sagemaker_session.train.assert_called_once()
149150
_, args = sagemaker_session.train.call_args
@@ -158,7 +159,8 @@ def test_framework_all_init_args(sagemaker_session):
158159
'stop_condition': {'MaxRuntimeInSeconds': 456},
159160
'role': sagemaker_session.expand_role(), 'job_name': None,
160161
'resource_config': {'VolumeSizeInGB': 123, 'InstanceCount': 3, 'VolumeKmsKeyId': 'volumekms',
161-
'InstanceType': 'ml.m4.xlarge'}}
162+
'InstanceType': 'ml.m4.xlarge'},
163+
'metric_definitions': [{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}]}
162164

163165

164166
def test_sagemaker_s3_uri_invalid(sagemaker_session):
@@ -711,7 +713,8 @@ def test_unsupported_type_in_dict():
711713
},
712714
'stop_condition': {'MaxRuntimeInSeconds': 86400},
713715
'tags': None,
714-
'vpc_config': None
716+
'vpc_config': None,
717+
'metric_definitions': None
715718
}
716719

717720
INPUT_CONFIG = [{

tests/unit/test_mxnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def _create_train_job(version):
9696
'MaxRuntimeInSeconds': 24 * 60 * 60
9797
},
9898
'tags': None,
99-
'vpc_config': None
99+
'vpc_config': None,
100+
'metric_definitions': None
100101
}
101102

102103

tests/unit/test_pytorch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ def _create_train_job(version):
112112
'MaxRuntimeInSeconds': 24 * 60 * 60
113113
},
114114
'tags': None,
115-
'vpc_config': None
115+
'vpc_config': None,
116+
'metric_definitions': None
116117
}
117118

118119

tests/unit/test_session.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def test_s3_input_all_arguments():
184184
JOB_NAME = 'jobname'
185185
TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
186186
VPC_CONFIG = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
187+
METRIC_DEFINITONS = [{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}]
187188

188189
DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
189190
'OutputDataConfig': {
@@ -268,7 +269,8 @@ def test_train_pack_to_request(sagemaker_session):
268269

269270
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
270271
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
271-
hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=VPC_CONFIG)
272+
hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=VPC_CONFIG,
273+
metric_definitions=None)
272274

273275
assert sagemaker_session.sagemaker_client.method_calls[0] == (
274276
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
@@ -439,13 +441,15 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
439441

440442
sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
441443
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
442-
vpc_config=VPC_CONFIG, hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS)
444+
vpc_config=VPC_CONFIG, hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS,
445+
metric_definitions=METRIC_DEFINITONS)
443446

444447
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
445448

446449
assert actual_train_args['VpcConfig'] == VPC_CONFIG
447450
assert actual_train_args['HyperParameters'] == hyperparameters
448451
assert actual_train_args['Tags'] == TAGS
452+
assert actual_train_args['AlgorithmSpecification']['MetricDefinitions'] == METRIC_DEFINITONS
449453

450454

451455
def test_transform_pack_to_request(sagemaker_session):

tests/unit/test_tf_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def _create_train_job(tf_version, script_mode=False, repo_name=IMAGE_REPO_NAME,
117117
'MaxRuntimeInSeconds': 24 * 60 * 60
118118
},
119119
'tags': None,
120-
'vpc_config': None
120+
'vpc_config': None,
121+
'metric_definitions': None
121122
}
122123

123124

tests/unit/test_tuner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,8 @@ def test_deploy_default(tuner):
413413
returned_training_job_description = {
414414
'AlgorithmSpecification': {
415415
'TrainingInputMode': 'File',
416-
'TrainingImage': IMAGE_NAME
416+
'TrainingImage': IMAGE_NAME,
417+
'MetricDefinitions': METRIC_DEFINTIONS,
417418
},
418419
'HyperParameters': {
419420
'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',

0 commit comments

Comments
 (0)