Skip to content

Commit a2dfff1

Browse files
danabens, Dewen Qi, qidewenwhen
committed
feature: Add latest metric service model (aws#757)
Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: qidewenwhen <[email protected]>
1 parent 2107d77 commit a2dfff1

File tree

8 files changed

+151
-1273
lines changed

8 files changed

+151
-1273
lines changed

src/sagemaker/experiments/_metrics.py

+35-41
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@
2323

2424
import dateutil.tz
2525

26+
from sagemaker.session import Session
27+
2628
METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", ".")
2729
METRIC_TS_LOWER_BOUND_TO_NOW = 1209600 # in seconds (14 days)
2830
METRIC_TS_UPPER_BOUND_FROM_NOW = 7200 # in seconds (2 hours)
2931

32+
BATCH_SIZE = 10
33+
3034
logging.basicConfig(level=logging.INFO)
3135
logger = logging.getLogger(__name__)
3236

@@ -171,7 +175,7 @@ def to_raw_metric_data(self):
171175
"Timestamp": int(self.Timestamp),
172176
}
173177
if self.Step is not None:
174-
raw_metric_data["IterationNumber"] = int(self.Step)
178+
raw_metric_data["Step"] = int(self.Step)
175179
return raw_metric_data
176180

177181
def __str__(self):
@@ -185,36 +189,27 @@ def __repr__(self):
185189
",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
186190
)
187191

188-
def to_request_item(self):
189-
"""Transform a RawMetricData item to a list item for BatchPutMetrics request."""
190-
item = {
191-
"MetricName": self.MetricName,
192-
"Timestamp": int(self.Timestamp),
193-
"Value": self.Value,
194-
}
195-
196-
if self.Step is not None:
197-
item["IterationNumber"] = self.Step
198-
199-
return item
200-
201192

202193
class _MetricsManager(object):
203194
"""Collects metrics and sends them directly to SageMaker Metrics data plane APIs."""
204195

205-
_BATCH_SIZE = 10
206-
207-
def __init__(self, resource_arn, sagemaker_session) -> None:
208-
"""Initiate a `_MetricsManager` instance
196+
def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None:
197+
"""Initialize a `_MetricsManager` instance
209198
210199
Args:
211-
resource_arn (str): The ARN of the resource to log metrics to
200+
trial_component_name (str): The Name of the Trial Component to log metrics to
212201
sagemaker_session (sagemaker.session.Session): Session object which
213202
manages interactions with Amazon SageMaker APIs and any other
214203
AWS services needed. If not specified, one is created using the
215204
default AWS configuration chain.
205+
sink (object): The metrics sink to use.
216206
"""
217-
self.sink = _SyncMetricsSink(resource_arn, sagemaker_session.sagemaker_metrics_client)
207+
if sink is None:
208+
self.sink = _SyncMetricsSink(
209+
trial_component_name, sagemaker_session.sagemaker_metrics_client
210+
)
211+
else:
212+
self.sink = sink
218213

219214
def log_metric(self, metric_name, value, timestamp=None, step=None):
220215
"""Sends a metric to metrics service."""
@@ -238,16 +233,14 @@ def close(self):
238233
class _SyncMetricsSink(object):
239234
"""Collects metrics and sends them directly to metrics service."""
240235

241-
_BATCH_SIZE = 10
242-
243-
def __init__(self, resource_arn, metrics_client) -> None:
244-
"""Initiate a `_MetricsManager` instance
236+
def __init__(self, trial_component_name, metrics_client) -> None:
237+
"""Initialize a `_SyncMetricsSink` instance
245238
246239
Args:
247-
resource_arn (str): The ARN of a Trial Component to log metrics.
240+
trial_component_name (str): The Name of the Trial Component to log metrics to.
248241
metrics_client (boto3.client): boto client for metrics service
249242
"""
250-
self._resource_arn = resource_arn
243+
self._trial_component_name = trial_component_name
251244
self._metrics_client = metrics_client
252245
self._buffer = []
253246

@@ -265,7 +258,7 @@ def _drain(self, close=False):
265258
if not self._buffer:
266259
return
267260

268-
if len(self._buffer) < self._BATCH_SIZE and not close:
261+
if len(self._buffer) < BATCH_SIZE and not close:
269262
return
270263

271264
# pop all the available metrics
@@ -276,7 +269,10 @@ def _drain(self, close=False):
276269
def _send_metrics(self, metrics):
277270
"""Calls BatchPutMetrics directly on the metrics service."""
278271
while metrics:
279-
batch, metrics = metrics[: self._BATCH_SIZE], metrics[self._BATCH_SIZE :]
272+
batch, metrics = (
273+
metrics[:BATCH_SIZE],
274+
metrics[BATCH_SIZE:],
275+
)
280276
request = self._construct_batch_put_metrics_request(batch)
281277
response = self._metrics_client.batch_put_metrics(**request)
282278
errors = response["Errors"] if "Errors" in response else None
@@ -287,7 +283,7 @@ def _send_metrics(self, metrics):
287283
def _construct_batch_put_metrics_request(self, batch):
288284
"""Creates dictionary object used as request to metrics service."""
289285
return {
290-
"ResourceArn": self._resource_arn,
286+
"TrialComponentName": self._trial_component_name,
291287
"MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
292288
}
293289

@@ -300,23 +296,21 @@ class _MetricQueue(object):
300296
"""A thread safe queue for sending metrics to SageMaker.
301297
302298
Args:
303-
resource_arn (str): the ARN of the resource
299+
trial_component_name (str): the name of the Trial Component to log metrics to
304300
metric_name (str): the name of the metric
305301
metrics_client (boto_client): the boto client for SageMaker Metrics service
306302
"""
307303

308-
_BATCH_SIZE = 10
309-
310304
_CONSUMER_SLEEP_SECONDS = 5
311305

312-
def __init__(self, resource_arn, metric_name, metrics_client):
306+
def __init__(self, trial_component_name, metric_name, metrics_client):
313307
# infinite queue size
314308
self._queue = queue.Queue()
315309
self._buffer = []
316310
self._thread = threading.Thread(target=self._run)
317311
self._started = False
318312
self._finished = False
319-
self._resource_arn = resource_arn
313+
self._trial_component_name = trial_component_name
320314
self._metrics_client = metrics_client
321315
self._metric_name = metric_name
322316
self._logged_metrics = 0
@@ -325,7 +319,7 @@ def log_metric(self, metric_data):
325319
"""Adds a metric data point to the queue"""
326320
self._buffer.append(metric_data)
327321

328-
if len(self._buffer) < self._BATCH_SIZE:
322+
if len(self._buffer) < BATCH_SIZE:
329323
return
330324

331325
self._enqueue_all()
@@ -354,7 +348,7 @@ def _construct_batch_put_metrics_request(self, batch):
354348
"""Creates dictionary object used as request to metrics service."""
355349

356350
return {
357-
"ResourceArn": self._resource_arn,
351+
"TrialComponentName": self._trial_component_name,
358352
"MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
359353
}
360354

@@ -382,14 +376,14 @@ class _AsyncMetricsSink(object):
382376

383377
_COMPLETE_SLEEP_SECONDS = 1.0
384378

385-
def __init__(self, resource_arn, metrics_client) -> None:
386-
"""Initiate a `_MetricsManager` instance
379+
def __init__(self, trial_component_name, metrics_client) -> None:
380+
"""Initialize a `_AsyncMetricsSink` instance
387381
388382
Args:
389-
resource_arn (str): The ARN of a Trial Component to log metrics.
383+
trial_component_name (str): The Name of the Trial Component to log metrics to.
390384
metrics_client (boto3.client): boto client for metrics service
391385
"""
392-
self._resource_arn = resource_arn
386+
self._trial_component_name = trial_component_name
393387
self._metrics_client = metrics_client
394388
self._buffer = []
395389
self._is_draining = False
@@ -402,7 +396,7 @@ def log_metric(self, metric_data):
402396
self._metric_queues[metric_data.MetricName].log_metric(metric_data)
403397
else:
404398
cur_metric_queue = _MetricQueue(
405-
self._resource_arn, metric_data.MetricName, self._metrics_client
399+
self._trial_component_name, metric_data.MetricName, self._metrics_client
406400
)
407401
self._metric_queues[metric_data.MetricName] = cur_metric_queue
408402
cur_metric_queue.log_metric(metric_data)

src/sagemaker/experiments/run.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def __init__(
187187
sagemaker_session=sagemaker_session,
188188
)
189189
self._metrics_manager = _MetricsManager(
190-
resource_arn=self._trial_component.trial_component_arn,
190+
trial_component_name=self._trial_component.trial_component_name,
191191
sagemaker_session=sagemaker_session,
192192
)
193193
self._inside_init_context = False

tests/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
normal_json = "file://./tests/data/experiment/resources/sagemaker-2017-07-24.normal.json"
2222
os.system(f"aws configure add-model --service-model {normal_json} --service-name sagemaker")
2323

24-
metrics_model_json = (
24+
public_metrics_model_json = (
2525
"file://./tests/data/experiment/resources/sagemaker-metrics-2022-09-30.normal.json"
2626
)
2727
os.system(
28-
f"aws configure add-model --service-model {metrics_model_json} --service-name sagemaker-metrics"
28+
f"aws configure add-model --service-model {public_metrics_model_json} --service-name sagemaker-metrics"
2929
)

0 commit comments

Comments
 (0)