18
18
import logging
19
19
import os
20
20
import time
21
+ import threading
22
+ import queue
21
23
22
24
import dateutil .tz
23
25
from botocore .config import Config
24
- from sagemaker .apiutils import _utils
25
26
26
27
27
28
METRICS_DIR = os .environ .get ("SAGEMAKER_METRICS_DIRECTORY" , "." )
@@ -164,6 +165,17 @@ def to_record(self):
164
165
"""Convert the `_RawMetricData` object to dict"""
165
166
return self .__dict__
166
167
168
def to_raw_metric_data(self):
    """Build the `RawMetricData` entry used in a BatchPutMetrics request.

    Returns:
        dict: `MetricName`, `Value` and an integer `Timestamp`, plus an
            integer `IterationNumber` when a step was recorded.
    """
    entry = dict(
        MetricName=self.MetricName,
        Value=self.Value,
        Timestamp=int(self.Timestamp),
    )
    if self.Step is None:
        # No step recorded: the optional field is left out entirely.
        return entry
    entry["IterationNumber"] = int(self.Step)
    return entry
178
+
167
179
def __str__(self):
    """Return the `repr()` text of this `_RawMetricData` object."""
    return f"{self!r}"
@@ -175,66 +187,96 @@ def __repr__(self):
175
187
"," .join (["{}={}" .format (k , repr (v )) for k , v in vars (self ).items ()]),
176
188
)
177
189
190
def to_request_item(self):
    """Transform this metric into a list item for a BatchPutMetrics request.

    Returns:
        dict: `MetricName`, an integer `Timestamp` and `Value`, plus an
            integer `IterationNumber` when a step was recorded.
    """
    item = {
        "MetricName": self.MetricName,
        "Timestamp": int(self.Timestamp),
        "Value": self.Value,
    }

    if self.Step is not None:
        # Cast exactly like `to_raw_metric_data` does, so both serializers
        # emit the same integer field for steps given as e.g. floats.
        item["IterationNumber"] = int(self.Step)

    return item
202
+
203
+
204
class _MetricsManager(object):
    """Collects metrics and sends them directly to SageMaker Metrics data plane APIs.

    Usable as a context manager; leaving the `with` block closes the sink.
    """

    _BATCH_SIZE = 10

    def __init__(self, resource_arn, sagemaker_session) -> None:
        """Initiate a `_MetricsManager` instance

        Args:
            resource_arn (str): The ARN of the resource to log metrics to
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. Required: its `boto_session` is used to
                create the metrics client.

        Raises:
            ValueError: If `sagemaker_session` is None.
        """
        self._get_metrics_client(sagemaker_session)
        self.sink = _SyncMetricsSink(resource_arn, self.metrics_client)

    def log_metric(self, metric_name, value, timestamp=None, step=None):
        """Sends a metric to metrics service.

        Args:
            metric_name (str): The name of the metric.
            value (float): The value of the metric.
            timestamp (datetime.datetime): Timestamp of the metric (default: None).
            step (int): Iteration number of the metric (default: None).
        """
        # Keyword arguments keep this call independent of the constructor's
        # positional parameter order.
        metric_data = _RawMetricData(
            metric_name=metric_name, value=value, timestamp=timestamp, step=step
        )
        self.sink.log_metric(metric_data)

    def __enter__(self):
        """Return self"""
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Execute self.close()

        Args:
            exc_type: The exception type, if the block raised.
            exc_value: The exception value, if the block raised.
            exc_traceback: The traceback, if the block raised.
        """
        # Returns None, so an exception raised inside the block propagates.
        self.close()

    def close(self):
        """Close the metrics object by closing its sink."""
        self.sink.close()

    def _get_metrics_client(self, sagemaker_session):
        """Create the metrics service client and store it on `self.metrics_client`.

        Args:
            sagemaker_session (sagemaker.session.Session): Session whose
                `boto_session` supplies the region and creates the client.

        Returns:
            The created client (also stored as `self.metrics_client`).

        Raises:
            ValueError: If `sagemaker_session` is None.
        """
        if sagemaker_session is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError("sagemaker_session is required to create the metrics client")

        # TODO move this client instantiation into Session
        config = Config(retries={"max_attempts": 10, "mode": "adaptive"})
        stage = "prod"
        region = sagemaker_session.boto_session.region_name
        endpoint = f"https://training-metrics.{stage}.{region}.ml-platform.aws.a2z.com"
        self.metrics_client = sagemaker_session.boto_session.client(
            "sagemaker-metrics", config=config, endpoint_url=endpoint
        )
        return self.metrics_client
208
251
209
- def log_metric (self , metric_name , value , timestamp = None , step = None ):
210
- """Sends a metric to metrics service.
252
+
253
+ class _SyncMetricsSink (object ):
254
+ """Collects metrics and sends them directly to metrics service."""
255
+
256
+ _BATCH_SIZE = 10
257
+
258
def __init__(self, resource_arn, metrics_client) -> None:
    """Initiate a `_SyncMetricsSink` instance

    Args:
        resource_arn (str): The ARN of a Trial Component to log metrics.
        metrics_client (boto3.client): boto client for metrics service
    """
    # Metrics wait here until they are drained to the service.
    self._buffer = list()
    self._metrics_client = metrics_client
    self._resource_arn = resource_arn
268
+
269
def log_metric(self, metric_data):
    """Buffer one metric, then run a drain pass on the buffer.

    Args:
        metric_data (_RawMetricData): The metric to record.
    """
    # Deliberately simple: the drain, and any BatchPutMetrics call it
    # makes, happens on the caller's own thread.
    self._buffer += [metric_data]
    self._drain()
226
276
227
277
def _drain (self , close = False ):
228
- """Pops off all metrics in the buffer and starts sending them to metrics service.
278
+ """Pops off all metrics in the buffer and starts sending them to metrics service."""
229
279
230
- Args:
231
- close (bool): Indicates if this method is invoked within the `close` function
232
- (default: False). If invoked in the `close` function, the remaining logged
233
- metrics in the buffer will be all sent out to the Metrics Service.
234
- Otherwise, the metrics will only be sent out if the number of them reaches the
235
- batch size.
236
- """
237
- # no metrics to send
238
280
if not self ._buffer :
239
281
return
240
282
@@ -247,60 +289,146 @@ def _drain(self, close=False):
247
289
self ._send_metrics (available_metrics )
248
290
249
291
def _send_metrics(self, metrics):
    """Calls BatchPutMetrics directly on the metrics service.

    The metrics are sent in consecutive batches of at most `_BATCH_SIZE`.

    Args:
        metrics (list[_RawMetricData]): The metrics to send.

    Raises:
        RuntimeError: If a response reports errors. Sending stops at the
            first failing batch; later batches are not sent.
    """
    for start in range(0, len(metrics), self._BATCH_SIZE):
        batch = metrics[start : start + self._BATCH_SIZE]
        request = self._construct_batch_put_metrics_request(batch)
        response = self._metrics_client.batch_put_metrics(**request)
        # Membership test (not `.get`) so that stand-in clients whose
        # responses are not plain dicts keep behaving as before.
        errors = response["Errors"] if "Errors" in response else None
        if errors:
            message = errors[0]["Message"]
            # RuntimeError instead of a bare Exception: still caught by
            # `except Exception`, but can also be handled specifically.
            raise RuntimeError(f'{len(errors)} errors with message "{message}"')
259
301
260
302
def _construct_batch_put_metrics_request(self, batch):
    """Assemble the keyword arguments for one BatchPutMetrics call.

    Args:
        batch (list[_RawMetricData]): The metrics to include in the request.

    Returns:
        dict: `ResourceArn` plus `MetricData`, the list of each metric's
            `to_raw_metric_data()` output in the same order as `batch`.
    """
    request = {"ResourceArn": self._resource_arn}
    request["MetricData"] = [metric.to_raw_metric_data() for metric in batch]
    return request
271
308
272
- @ staticmethod
273
- def _to_raw_metric_data ( metric_data ):
274
- """Transform a RawMetricData item to a list item for BatchPutMetrics request.
309
def close(self):
    """Run a final drain pass on the buffer, flagged as the closing one."""
    # NOTE(review): `close=True` presumably makes `_drain` send whatever is
    # still buffered regardless of batch size - confirm against `_drain`.
    self._drain(close=True)
275
312
276
- Args:
277
- metric_data (_RawMetricData): The `_RawMetricData` object to be transformed.
278
- """
279
- item = {
280
- "MetricName" : metric_data .MetricName ,
281
- "Timestamp" : int (metric_data .Timestamp ),
282
- "Value" : metric_data .Value ,
313
+
314
class _MetricQueue(object):
    """A thread safe queue for sending metrics to SageMaker.

    Metrics are buffered by the producer, moved to an internal queue in
    batches, and sent by a single background consumer thread.

    Args:
        resource_arn (str): the ARN of the resource
        metric_name (str): the name of the metric
        metrics_client (boto_client): the boto client for SageMaker Metrics service
    """

    _BATCH_SIZE = 10

    # Longest time the consumer waits on an empty queue before it rechecks
    # whether the queue has been closed.
    _CONSUMER_SLEEP_SECONDS = 5

    # Placed on the queue by close() only to wake a waiting consumer.
    _WAKE_UP = object()

    def __init__(self, resource_arn, metric_name, metrics_client):
        # infinite queue size
        self._queue = queue.Queue()
        self._buffer = []
        self._thread = threading.Thread(target=self._run)
        self._started = False
        self._finished = False
        self._resource_arn = resource_arn
        self._metrics_client = metrics_client
        self._metric_name = metric_name
        self._logged_metrics = 0
        # Re-entrant so the public methods can call the helpers below while
        # already holding it.
        self._lock = threading.RLock()

    def log_metric(self, metric_data):
        """Adds a metric data point to the queue

        The point is buffered; once `_BATCH_SIZE` points are buffered they
        are enqueued as one batch and the consumer thread is started if it
        is not running yet.
        """
        with self._lock:
            self._buffer.append(metric_data)

            if len(self._buffer) < self._BATCH_SIZE:
                return

            self._enqueue_all()
            self._start_consumer()

    def _start_consumer(self):
        """Starts the consumer thread unless it was already started"""
        with self._lock:
            if not self._started:
                self._thread.start()
                self._started = True

    def _run(self):
        """Runs on the metric thread; sends queued batches to SageMaker

        Returns once the queue has been closed and is empty.
        """
        while not (self._finished and self._queue.empty()):
            try:
                batch = self._queue.get(timeout=self._CONSUMER_SLEEP_SECONDS)
            except queue.Empty:
                continue

            if batch is self._WAKE_UP:
                continue

            try:
                self._send_metrics(batch)
            except Exception:  # thread boundary: keep draining later batches
                logging.exception(
                    "Failed to send a batch of %d '%s' metric(s)",
                    len(batch),
                    self._metric_name,
                )

    def _send_metrics(self, metrics_batch):
        """Calls BatchPutMetrics directly on the metrics service."""
        request = self._construct_batch_put_metrics_request(metrics_batch)
        self._metrics_client.batch_put_metrics(**request)
        # Counted only after the call returned without raising.
        self._logged_metrics += len(metrics_batch)

    def _construct_batch_put_metrics_request(self, batch):
        """Creates dictionary object used as request to metrics service."""

        return {
            "ResourceArn": self._resource_arn,
            "MetricData": [metric.to_raw_metric_data() for metric in batch],
        }

    def _enqueue_all(self):
        """Enqueue all buffered metrics to be sent to SageMaker"""

        with self._lock:
            available_metrics, self._buffer = self._buffer, []
            if available_metrics:
                self._queue.put(available_metrics)

    def close(self):
        """Flushes any buffered metrics

        Enqueues whatever is still buffered and marks the queue finished.
        The consumer thread is started here if a partial batch is pending
        and no full batch ever started it; otherwise those metrics would
        never be sent. This does not wait for the consumer; poll
        `is_active()` for completion.
        """
        with self._lock:
            self._enqueue_all()
            # Set before the wake-up marker is queued so that a consumer
            # woken by the marker sees the finished state.
            self._finished = True

            if not self._started and self._queue.empty():
                # Nothing was ever queued; no thread is needed.
                return

            self._queue.put(self._WAKE_UP)
            self._start_consumer()

    def is_active(self):
        """Is the thread active (still draining metrics to SageMaker)"""

        return self._thread.is_alive()
393
+
394
+
395
class _AsyncMetricsSink(object):
    """Routes each metric to a per-metric-name `_MetricQueue`."""

    _COMPLETE_SLEEP_SECONDS = 1.0

    def __init__(self, resource_arn, metrics_client) -> None:
        """Initiate a `_AsyncMetricsSink` instance

        Args:
            resource_arn (str): The ARN of a Trial Component to log metrics.
            metrics_client (boto3.client): boto client for metrics service
        """
        self._resource_arn = resource_arn
        self._metrics_client = metrics_client
        self._buffer = []
        self._is_draining = False
        # Maps a metric name to the `_MetricQueue` handling that metric.
        self._metric_queues = {}

    def log_metric(self, metric_data):
        """Hand a metric to the queue for its name, creating the queue on first use.

        Args:
            metric_data (_RawMetricData): The metric to record.
        """
        name = metric_data.MetricName
        metric_queue = self._metric_queues.get(name)
        if metric_queue is None:
            metric_queue = _MetricQueue(self._resource_arn, name, self._metrics_client)
            self._metric_queues[name] = metric_queue
        metric_queue.log_metric(metric_data)

    def close(self):
        """Close every metric queue and wait until none of them is active."""
        logging.debug("Closing")
        # Live view of the dict values, reused for both loops below.
        metric_queues = self._metric_queues.values()
        for metric_queue in metric_queues:
            metric_queue.close()

        # TODO should probably use join
        while any(metric_queue.is_active() for metric_queue in metric_queues):
            time.sleep(self._COMPLETE_SLEEP_SECONDS)
        logging.debug("Closed")
0 commit comments