23
23
24
24
import dateutil .tz
25
25
26
+ from sagemaker .session import Session
27
+
26
28
METRICS_DIR = os .environ .get ("SAGEMAKER_METRICS_DIRECTORY" , "." )
27
29
METRIC_TS_LOWER_BOUND_TO_NOW = 1209600 # in seconds
28
30
METRIC_TS_UPPER_BOUND_FROM_NOW = 7200 # in seconds
29
31
32
+ BATCH_SIZE = 10
33
+
30
34
logging .basicConfig (level = logging .INFO )
31
35
logger = logging .getLogger (__name__ )
32
36
@@ -171,7 +175,7 @@ def to_raw_metric_data(self):
171
175
"Timestamp" : int (self .Timestamp ),
172
176
}
173
177
if self .Step is not None :
174
- raw_metric_data ["IterationNumber " ] = int (self .Step )
178
+ raw_metric_data ["Step " ] = int (self .Step )
175
179
return raw_metric_data
176
180
177
181
def __str__ (self ):
@@ -185,36 +189,27 @@ def __repr__(self):
185
189
"," .join (["{}={}" .format (k , repr (v )) for k , v in vars (self ).items ()]),
186
190
)
187
191
188
- def to_request_item (self ):
189
- """Transform a RawMetricData item to a list item for BatchPutMetrics request."""
190
- item = {
191
- "MetricName" : self .MetricName ,
192
- "Timestamp" : int (self .Timestamp ),
193
- "Value" : self .Value ,
194
- }
195
-
196
- if self .Step is not None :
197
- item ["IterationNumber" ] = self .Step
198
-
199
- return item
200
-
201
192
202
193
class _MetricsManager (object ):
203
194
"""Collects metrics and sends them directly to SageMaker Metrics data plane APIs."""
204
195
205
- _BATCH_SIZE = 10
206
-
207
- def __init__ (self , resource_arn , sagemaker_session ) -> None :
208
- """Initiate a `_MetricsManager` instance
196
+ def __init__ (self , trial_component_name : str , sagemaker_session : Session , sink = None ) -> None :
197
+ """Initialize a `_MetricsManager` instance
209
198
210
199
Args:
211
- resource_arn (str): The ARN of the resource to log metrics to
200
+ trial_component_name (str): The Name of the Trial Component to log metrics to
212
201
sagemaker_session (sagemaker.session.Session): Session object which
213
202
manages interactions with Amazon SageMaker APIs and any other
214
203
AWS services needed. If not specified, one is created using the
215
204
default AWS configuration chain.
205
+ sink (object): The metrics sink to use.
216
206
"""
217
- self .sink = _SyncMetricsSink (resource_arn , sagemaker_session .sagemaker_metrics_client )
207
+ if sink is None :
208
+ self .sink = _SyncMetricsSink (
209
+ trial_component_name , sagemaker_session .sagemaker_metrics_client
210
+ )
211
+ else :
212
+ self .sink = sink
218
213
219
214
def log_metric (self , metric_name , value , timestamp = None , step = None ):
220
215
"""Sends a metric to metrics service."""
@@ -238,16 +233,14 @@ def close(self):
238
233
class _SyncMetricsSink (object ):
239
234
"""Collects metrics and sends them directly to metrics service."""
240
235
241
- _BATCH_SIZE = 10
242
-
243
- def __init__ (self , resource_arn , metrics_client ) -> None :
244
- """Initiate a `_MetricsManager` instance
236
+ def __init__ (self , trial_component_name , metrics_client ) -> None :
237
+ """Initialize a `_SyncMetricsSink` instance
245
238
246
239
Args:
247
- resource_arn (str): The ARN of a Trial Component to log metrics.
240
+ trial_component_name (str): The Name of the Trial Component to log metrics to.
248
241
metrics_client (boto3.client): boto client for metrics service
249
242
"""
250
- self ._resource_arn = resource_arn
243
+ self ._trial_component_name = trial_component_name
251
244
self ._metrics_client = metrics_client
252
245
self ._buffer = []
253
246
@@ -265,7 +258,7 @@ def _drain(self, close=False):
265
258
if not self ._buffer :
266
259
return
267
260
268
- if len (self ._buffer ) < self . _BATCH_SIZE and not close :
261
+ if len (self ._buffer ) < BATCH_SIZE and not close :
269
262
return
270
263
271
264
# pop all the available metrics
@@ -276,7 +269,10 @@ def _drain(self, close=False):
276
269
def _send_metrics (self , metrics ):
277
270
"""Calls BatchPutMetrics directly on the metrics service."""
278
271
while metrics :
279
- batch , metrics = metrics [: self ._BATCH_SIZE ], metrics [self ._BATCH_SIZE :]
272
+ batch , metrics = (
273
+ metrics [:BATCH_SIZE ],
274
+ metrics [BATCH_SIZE :],
275
+ )
280
276
request = self ._construct_batch_put_metrics_request (batch )
281
277
response = self ._metrics_client .batch_put_metrics (** request )
282
278
errors = response ["Errors" ] if "Errors" in response else None
@@ -287,7 +283,7 @@ def _send_metrics(self, metrics):
287
283
def _construct_batch_put_metrics_request (self , batch ):
288
284
"""Creates dictionary object used as request to metrics service."""
289
285
return {
290
- "ResourceArn " : self ._resource_arn ,
286
+ "TrialComponentName " : self ._trial_component_name ,
291
287
"MetricData" : list (map (lambda x : x .to_raw_metric_data (), batch )),
292
288
}
293
289
@@ -300,23 +296,21 @@ class _MetricQueue(object):
300
296
"""A thread safe queue for sending metrics to SageMaker.
301
297
302
298
Args:
303
- resource_arn (str): the ARN of the resource
299
+ trial_component_name (str): the name of the Trial Component
304
300
metric_name (str): the name of the metric
305
301
metrics_client (boto_client): the boto client for SageMaker Metrics service
306
302
"""
307
303
308
- _BATCH_SIZE = 10
309
-
310
304
_CONSUMER_SLEEP_SECONDS = 5
311
305
312
- def __init__ (self , resource_arn , metric_name , metrics_client ):
306
+ def __init__ (self , trial_component_name , metric_name , metrics_client ):
313
307
# infinite queue size
314
308
self ._queue = queue .Queue ()
315
309
self ._buffer = []
316
310
self ._thread = threading .Thread (target = self ._run )
317
311
self ._started = False
318
312
self ._finished = False
319
- self ._resource_arn = resource_arn
313
+ self ._trial_component_name = trial_component_name
320
314
self ._metrics_client = metrics_client
321
315
self ._metric_name = metric_name
322
316
self ._logged_metrics = 0
@@ -325,7 +319,7 @@ def log_metric(self, metric_data):
325
319
"""Adds a metric data point to the queue"""
326
320
self ._buffer .append (metric_data )
327
321
328
- if len (self ._buffer ) < self . _BATCH_SIZE :
322
+ if len (self ._buffer ) < BATCH_SIZE :
329
323
return
330
324
331
325
self ._enqueue_all ()
@@ -354,7 +348,7 @@ def _construct_batch_put_metrics_request(self, batch):
354
348
"""Creates dictionary object used as request to metrics service."""
355
349
356
350
return {
357
- "ResourceArn " : self ._resource_arn ,
351
+ "TrialComponentName " : self ._trial_component_name ,
358
352
"MetricData" : list (map (lambda x : x .to_raw_metric_data (), batch )),
359
353
}
360
354
@@ -382,14 +376,14 @@ class _AsyncMetricsSink(object):
382
376
383
377
_COMPLETE_SLEEP_SECONDS = 1.0
384
378
385
- def __init__ (self , resource_arn , metrics_client ) -> None :
386
- """Initiate a `_MetricsManager ` instance
379
+ def __init__ (self , trial_component_name , metrics_client ) -> None :
380
+ """Initialize a `_AsyncMetricsSink ` instance
387
381
388
382
Args:
389
- resource_arn (str): The ARN of a Trial Component to log metrics.
383
+ trial_component_name (str): The Name of the Trial Component to log metrics to .
390
384
metrics_client (boto3.client): boto client for metrics service
391
385
"""
392
- self ._resource_arn = resource_arn
386
+ self ._trial_component_name = trial_component_name
393
387
self ._metrics_client = metrics_client
394
388
self ._buffer = []
395
389
self ._is_draining = False
@@ -402,7 +396,7 @@ def log_metric(self, metric_data):
402
396
self ._metric_queues [metric_data .MetricName ].log_metric (metric_data )
403
397
else :
404
398
cur_metric_queue = _MetricQueue (
405
- self ._resource_arn , metric_data .MetricName , self ._metrics_client
399
+ self ._trial_component_name , metric_data .MetricName , self ._metrics_client
406
400
)
407
401
self ._metric_queues [metric_data .MetricName ] = cur_metric_queue
408
402
cur_metric_queue .log_metric (metric_data )
0 commit comments