Skip to content

Commit 25299c4

Browse files
danabensDewen Qiqidewenwhen
committed
Add async metrics sink (aws#739)
Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: qidewenwhen <[email protected]>
1 parent 620f9b6 commit 25299c4

File tree

5 files changed

+225
-76
lines changed

5 files changed

+225
-76
lines changed

src/sagemaker/experiments/metrics.py renamed to src/sagemaker/experiments/_metrics.py

+197-69
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@
1818
import logging
1919
import os
2020
import time
21+
import threading
22+
import queue
2123

2224
import dateutil.tz
2325
from botocore.config import Config
24-
from sagemaker.apiutils import _utils
2526

2627

2728
METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", ".")
@@ -164,6 +165,17 @@ def to_record(self):
164165
"""Convert the `_RawMetricData` object to dict"""
165166
return self.__dict__
166167

168+
def to_raw_metric_data(self):
169+
"""Converts the metric data to a BatchPutMetrics RawMetricData item"""
170+
raw_metric_data = {
171+
"MetricName": self.MetricName,
172+
"Value": self.Value,
173+
"Timestamp": int(self.Timestamp),
174+
}
175+
if self.Step is not None:
176+
raw_metric_data["IterationNumber"] = int(self.Step)
177+
return raw_metric_data
178+
167179
def __str__(self):
168180
"""String representation of the `_RawMetricData` object."""
169181
return repr(self)
@@ -175,66 +187,96 @@ def __repr__(self):
175187
",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
176188
)
177189

190+
def to_request_item(self):
191+
"""Transform a RawMetricData item to a list item for BatchPutMetrics request."""
192+
item = {
193+
"MetricName": self.MetricName,
194+
"Timestamp": int(self.Timestamp),
195+
"Value": self.Value,
196+
}
178197

179-
class _MetricsManager(object):
180-
"""Collects metrics and sends them directly to metrics service.
198+
if self.Step is not None:
199+
item["IterationNumber"] = self.Step
181200

182-
Note this is a draft implementation for beta and will change significantly prior to launch.
183-
"""
201+
return item
202+
203+
204+
class _MetricsManager(object):
205+
"""Collects metrics and sends them directly to SageMaker Metrics data plane APIs."""
184206

185207
_BATCH_SIZE = 10
186208

187-
def __init__(self, resource_arn, sagemaker_session=None) -> None:
209+
def __init__(self, resource_arn, sagemaker_session) -> None:
188210
"""Initiate a `_MetricsManager` instance
189211
190212
Args:
191-
resource_arn (str): The ARN of a Trial Component to log metrics.
213+
resource_arn (str): The ARN of the resource to log metrics to
192214
sagemaker_session (sagemaker.session.Session): Session object which
193215
manages interactions with Amazon SageMaker APIs and any other
194216
AWS services needed. If not specified, one is created using the
195217
default AWS configuration chain.
196218
"""
197-
self._resource_arn = resource_arn
198-
self.sagemaker_session = sagemaker_session or _utils.default_session()
199-
self._buffer = []
200-
# this client instantiation will need to go into Session
219+
self._get_metrics_client(sagemaker_session)
220+
self.sink = _SyncMetricsSink(resource_arn, self.metrics_client)
221+
222+
def log_metric(self, metric_name, value, timestamp=None, step=None):
223+
"""Sends a metric to metrics service."""
224+
225+
metric_data = _RawMetricData(metric_name, value, timestamp, step)
226+
self.sink.log_metric(metric_data)
227+
228+
def __enter__(self):
229+
"""Return self"""
230+
return self
231+
232+
def __exit__(self, exc_type, exc_value, exc_traceback):
233+
"""Execute self.close()"""
234+
self.sink.close()
235+
236+
def close(self):
237+
"""Close the metrics object."""
238+
self.sink.close()
239+
240+
def _get_metrics_client(self, sagemaker_session):
241+
"""Return self"""
242+
243+
# TODO move this client instantiation into Session
201244
config = Config(retries={"max_attempts": 10, "mode": "adaptive"})
202245
stage = "prod"
203-
region = self.sagemaker_session.boto_session.region_name
246+
region = sagemaker_session.boto_session.region_name
204247
endpoint = f"https://training-metrics.{stage}.{region}.ml-platform.aws.a2z.com"
205-
self.metrics_service_client = self.sagemaker_session.boto_session.client(
248+
self.metrics_client = sagemaker_session.boto_session.client(
206249
"sagemaker-metrics", config=config, endpoint_url=endpoint
207250
)
208251

209-
def log_metric(self, metric_name, value, timestamp=None, step=None):
210-
"""Sends a metric to metrics service.
252+
253+
class _SyncMetricsSink(object):
254+
"""Collects metrics and sends them directly to metrics service."""
255+
256+
_BATCH_SIZE = 10
257+
258+
def __init__(self, resource_arn, metrics_client) -> None:
259+
"""Initiate a `_MetricsManager` instance
211260
212261
Args:
213-
metric_name (str): The name of the metric.
214-
value (float): The value of the metric.
215-
timestamp (datetime.datetime): Timestamp of the metric.
216-
If not specified, the current UTC time will be used.
217-
step (int): Iteration number of the metric (default: None).
262+
resource_arn (str): The ARN of a Trial Component to log metrics.
263+
metrics_client (boto3.client): boto client for metrics service
218264
"""
219-
metric_data = _RawMetricData(
220-
metric_name=metric_name, value=value, timestamp=timestamp, step=step
221-
)
265+
self._resource_arn = resource_arn
266+
self._metrics_client = metrics_client
267+
self._buffer = []
268+
269+
def log_metric(self, metric_data):
270+
"""Sends a metric to metrics service."""
271+
222272
# this is a simplistic solution which calls BatchPutMetrics
223273
# on the same thread as the client code
224274
self._buffer.append(metric_data)
225275
self._drain()
226276

227277
def _drain(self, close=False):
228-
"""Pops off all metrics in the buffer and starts sending them to metrics service.
278+
"""Pops off all metrics in the buffer and starts sending them to metrics service."""
229279

230-
Args:
231-
close (bool): Indicates if this method is invoked within the `close` function
232-
(default: False). If invoked in the `close` function, the remaining logged
233-
metrics in the buffer will be all sent out to the Metrics Service.
234-
Otherwise, the metrics will only be sent out if the number of them reaches the
235-
batch size.
236-
"""
237-
# no metrics to send
238280
if not self._buffer:
239281
return
240282

@@ -247,60 +289,146 @@ def _drain(self, close=False):
247289
self._send_metrics(available_metrics)
248290

249291
def _send_metrics(self, metrics):
250-
"""Calls BatchPutMetrics directly on the metrics service.
251-
252-
Args:
253-
metrics (list[_RawMetricData]): A list of `_RawMetricData` objects.
254-
"""
292+
"""Calls BatchPutMetrics directly on the metrics service."""
255293
while metrics:
256294
batch, metrics = metrics[: self._BATCH_SIZE], metrics[self._BATCH_SIZE :]
257295
request = self._construct_batch_put_metrics_request(batch)
258-
self.metrics_service_client.batch_put_metrics(**request)
296+
response = self._metrics_client.batch_put_metrics(**request)
297+
errors = response["Errors"] if "Errors" in response else None
298+
if errors:
299+
message = errors[0]["Message"]
300+
raise Exception(f'{len(errors)} errors with message "{message}"')
259301

260302
def _construct_batch_put_metrics_request(self, batch):
261-
"""Creates dictionary object used as request to metrics service.
262-
263-
Args:
264-
batch (list[_RawMetricData]): A list of `_RawMetricData` objects,
265-
whose length is within the batch size limitation.
266-
"""
303+
"""Creates dictionary object used as request to metrics service."""
267304
return {
268305
"ResourceArn": self._resource_arn,
269-
"MetricData": list(map(self._to_raw_metric_data, batch)),
306+
"MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
270307
}
271308

272-
@staticmethod
273-
def _to_raw_metric_data(metric_data):
274-
"""Transform a RawMetricData item to a list item for BatchPutMetrics request.
309+
def close(self):
310+
"""Drains any remaining metrics."""
311+
self._drain(close=True)
275312

276-
Args:
277-
metric_data (_RawMetricData): The `_RawMetricData` object to be transformed.
278-
"""
279-
item = {
280-
"MetricName": metric_data.MetricName,
281-
"Timestamp": int(metric_data.Timestamp),
282-
"Value": metric_data.Value,
313+
314+
class _MetricQueue(object):
315+
"""A thread safe queue for sending metrics to SageMaker.
316+
317+
Args:
318+
resource_arn (str): the ARN of the resource
319+
metric_name (str): the name of the metric
320+
metrics_client (boto_client): the boto client for SageMaker Metrics service
321+
"""
322+
323+
_BATCH_SIZE = 10
324+
325+
_CONSUMER_SLEEP_SECONDS = 5
326+
327+
def __init__(self, resource_arn, metric_name, metrics_client):
328+
# infinite queue size
329+
self._queue = queue.Queue()
330+
self._buffer = []
331+
self._thread = threading.Thread(target=self._run)
332+
self._started = False
333+
self._finished = False
334+
self._resource_arn = resource_arn
335+
self._metrics_client = metrics_client
336+
self._metric_name = metric_name
337+
self._logged_metrics = 0
338+
339+
def log_metric(self, metric_data):
340+
"""Adds a metric data point to the queue"""
341+
self._buffer.append(metric_data)
342+
343+
if len(self._buffer) < self._BATCH_SIZE:
344+
return
345+
346+
self._enqueue_all()
347+
348+
if not self._started:
349+
self._thread.start()
350+
self._started = True
351+
352+
def _run(self):
353+
"""Starts the metric thread which sends metrics to SageMaker in batches"""
354+
355+
while not self._queue.empty() or not self._finished:
356+
if self._queue.empty():
357+
time.sleep(self._CONSUMER_SLEEP_SECONDS)
358+
else:
359+
batch = self._queue.get()
360+
self._send_metrics(batch)
361+
362+
def _send_metrics(self, metrics_batch):
363+
"""Calls BatchPutMetrics directly on the metrics service."""
364+
request = self._construct_batch_put_metrics_request(metrics_batch)
365+
self._logged_metrics += len(metrics_batch)
366+
self._metrics_client.batch_put_metrics(**request)
367+
368+
def _construct_batch_put_metrics_request(self, batch):
369+
"""Creates dictionary object used as request to metrics service."""
370+
371+
return {
372+
"ResourceArn": self._resource_arn,
373+
"MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
283374
}
284375

285-
if metric_data.Step is not None:
286-
item["IterationNumber"] = metric_data.Step
376+
def _enqueue_all(self):
377+
"""Enqueue all buffered metrics to be sent to SageMaker"""
287378

288-
return item
379+
available_metrics, self._buffer = self._buffer, []
380+
if available_metrics:
381+
self._queue.put(available_metrics)
289382

290383
def close(self):
291-
"""Drain the metrics buffer and send metrics to Metrics Service."""
292-
self._drain(close=True)
384+
"""Flushes any buffered metrics"""
293385

294-
def __enter__(self):
295-
"""Return self"""
296-
return self
386+
self._enqueue_all()
387+
self._finished = True
297388

298-
def __exit__(self, exc_type, exc_value, exc_traceback):
299-
"""Execute self.close() to send out metrics in the buffer.
389+
def is_active(self):
390+
"""Is the thread active (still draining metrics to SageMaker)"""
391+
392+
return self._thread.is_alive()
393+
394+
395+
class _AsyncMetricsSink(object):
396+
"""Collects metrics and sends them directly to metrics service."""
397+
398+
_COMPLETE_SLEEP_SECONDS = 1.0
399+
400+
def __init__(self, resource_arn, metrics_client) -> None:
401+
"""Initiate a `_MetricsManager` instance
300402
301403
Args:
302-
exc_type (str): The exception type.
303-
exc_value (str): The exception value.
304-
exc_traceback (str): The stack trace of the exception.
404+
resource_arn (str): The ARN of a Trial Component to log metrics.
405+
metrics_client (boto3.client): boto client for metrics service
305406
"""
306-
self.close()
407+
self._resource_arn = resource_arn
408+
self._metrics_client = metrics_client
409+
self._buffer = []
410+
self._is_draining = False
411+
self._metric_queues = {}
412+
413+
def log_metric(self, metric_data):
414+
"""Sends a metric to metrics service."""
415+
416+
if metric_data.MetricName in self._metric_queues:
417+
self._metric_queues[metric_data.MetricName].log_metric(metric_data)
418+
else:
419+
cur_metric_queue = _MetricQueue(
420+
self._resource_arn, metric_data.MetricName, self._metrics_client
421+
)
422+
self._metric_queues[metric_data.MetricName] = cur_metric_queue
423+
cur_metric_queue.log_metric(metric_data)
424+
425+
def close(self):
426+
"""Closes the metric file."""
427+
logging.debug("Closing")
428+
for q in self._metric_queues.values():
429+
q.close()
430+
431+
# TODO should probably use join
432+
while any(map(lambda x: x.is_active(), self._metric_queues.values())):
433+
time.sleep(self._COMPLETE_SLEEP_SECONDS)
434+
logging.debug("Closed")

src/sagemaker/experiments/run.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from sagemaker.experiments._environment import _RunEnvironment, _EnvironmentType
3434
from sagemaker.experiments._run_context import _RunContext
3535
from sagemaker.experiments.experiment import _Experiment
36-
from sagemaker.experiments.metrics import _MetricsManager
36+
from sagemaker.experiments._metrics import _MetricsManager
3737
from sagemaker.experiments.trial import _Trial
3838
from sagemaker.experiments.trial_component import _TrialComponent
3939

tests/integ/sagemaker/experiments/test_metrics.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
import random
15-
from sagemaker.experiments.metrics import _MetricsManager
15+
from sagemaker.experiments._metrics import _MetricsManager
1616
from sagemaker.experiments.trial_component import _TrialComponent
1717

1818

@@ -31,7 +31,7 @@ def test_epoch(trial_component_obj, sagemaker_session):
3131
assert updated_tc.metrics[0].metric_name == metric_name
3232

3333

34-
def test_no_epoch(trial_component_obj, sagemaker_session):
34+
def test_timestamp(trial_component_obj, sagemaker_session):
3535
# The fixture creates deletes, just ensure fixture is used at least once
3636
metric_name = "test-x-timestamp"
3737
with _MetricsManager(trial_component_obj.trial_component_arn, sagemaker_session) as mm:
@@ -42,5 +42,7 @@ def test_no_epoch(trial_component_obj, sagemaker_session):
4242
trial_component_name=trial_component_obj.trial_component_name,
4343
sagemaker_session=sagemaker_session,
4444
)
45-
assert len(updated_tc.metrics) == 1
46-
assert updated_tc.metrics[0].metric_name == metric_name
45+
# the test-x-step data is added in the previous test_epoch test
46+
assert len(updated_tc.metrics) == 2
47+
assert updated_tc.metrics[0].metric_name == "test-x-step"
48+
assert updated_tc.metrics[1].metric_name == "test-x-timestamp"

tests/integ/sagemaker/experiments/test_run.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sagemaker.s3 import S3Uploader
2525
from sagemaker.xgboost import XGBoostModel
2626
from tests.integ import DATA_DIR
27-
from sagemaker.experiments.metrics import _MetricsManager
27+
from sagemaker.experiments._metrics import _MetricsManager
2828
from sagemaker.experiments.trial_component import _TrialComponent
2929
from sagemaker.sklearn import SKLearn
3030
from sagemaker.utils import retry_with_backoff, unique_name_from_base

0 commit comments

Comments
 (0)