Skip to content

Commit bca16b1

Browse files
danabensDewen Qiqidewenwhen
committed
add metrics client to session object (aws#745)
Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: Dana Benson <[email protected]> Co-authored-by: qidewenwhen <[email protected]>
1 parent be49268 commit bca16b1

11 files changed

+35290
-179
lines changed

src/sagemaker/session.py

+13
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def __init__(
8989
sagemaker_featurestore_runtime_client=None,
9090
default_bucket=None,
9191
settings=SessionSettings(),
92+
sagemaker_metrics_client=None,
9293
):
9394
"""Initialize a SageMaker ``Session``.
9495
@@ -116,6 +117,10 @@ def __init__(
116117
Example: "sagemaker-my-custom-bucket".
117118
settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
118119
parameters to apply to the session.
120+
sagemaker_metrics_client (boto3.SageMakerMetrics.Client):
121+
Client which makes SageMaker Metrics related calls to Amazon SageMaker
122+
(default: None). If not provided, one will be created using
123+
this instance's ``boto_session``.
119124
"""
120125
self._default_bucket = None
121126
self._default_bucket_name_override = default_bucket
@@ -130,6 +135,7 @@ def __init__(
130135
sagemaker_client=sagemaker_client,
131136
sagemaker_runtime_client=sagemaker_runtime_client,
132137
sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client,
138+
sagemaker_metrics_client=sagemaker_metrics_client,
133139
)
134140

135141
def _initialize(
@@ -138,6 +144,7 @@ def _initialize(
138144
sagemaker_client,
139145
sagemaker_runtime_client,
140146
sagemaker_featurestore_runtime_client,
147+
sagemaker_metrics_client,
141148
):
142149
"""Initialize this SageMaker Session.
143150
@@ -172,6 +179,12 @@ def _initialize(
172179
"sagemaker-featurestore-runtime"
173180
)
174181

182+
if sagemaker_metrics_client:
183+
self.sagemaker_metrics_client = sagemaker_metrics_client
184+
else:
185+
self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics")
186+
prepend_user_agent(self.sagemaker_metrics_client)
187+
175188
self.local_mode = False
176189

177190
@property
+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
FROM ubuntu:22.04
15+
16+
ARG script
17+
ARG library
18+
ARG botomodel
19+
ARG metricsmodel
20+
21+
RUN apt-get update && apt-get install -y --no-install-recommends \
22+
python3 \
23+
curl \
24+
ca-certificates \
25+
awscli \
26+
&& \
27+
apt-get clean && \
28+
rm -rf /var/lib/apt/lists/*
29+
30+
RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \
31+
python3 get-pip.py && \
32+
rm get-pip.py
33+
34+
35+
WORKDIR /root
36+
37+
COPY $library .
38+
39+
# use a custom model
40+
# TODO: coumment out these four lines once the new API changes for Run release
41+
COPY $botomodel sagemaker-2017-07-24.normal.json
42+
RUN aws configure add-model --service-name sagemaker --service-model file://sagemaker-2017-07-24.normal.json
43+
COPY $metricsmodel sagemaker-metrics-2022-09-30.normal.json
44+
RUN aws configure add-model --service-name sagemaker-metrics --service-model file://sagemaker-metrics-2022-09-30.normal.json
45+
46+
RUN python3 -m pip install $(basename $library)
47+
48+
COPY $script script.py
49+
50+
ENTRYPOINT ["python3", "./script.py"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This script file runs on SageMaker training job"""
14+
from __future__ import absolute_import
15+
16+
import logging
17+
import time
18+
19+
import os
20+
21+
import boto3
22+
23+
from sagemaker import Session
24+
from sagemaker.experiments.run import Run
25+
26+
for key, value in os.environ.items():
27+
logging.info("OS env var - {}: {}".format(key, value))
28+
29+
boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
30+
sagemaker_session = Session(boto_session=boto_session)
31+
32+
with Run.init(
33+
experiment_name="my-train-job-exp-in-script",
34+
run_name="my-train-job-run-in-script",
35+
sagemaker_session=sagemaker_session,
36+
) as run:
37+
logging.info(f"Run name: {run.run_name}")
38+
logging.info(f"Experiment name: {run.experiment_name}")
39+
logging.info(f"Trial component name: {run._trial_component.trial_component_name}")
40+
run.log_parameter("p1", 1.0)
41+
run.log_parameter("p2", 2.0)
42+
if "TRAINING_JOB_ARN" in os.environ:
43+
for i in range(2):
44+
run.log_metric("A", i)
45+
for i in range(2):
46+
run.log_metric("B", i)
47+
for i in range(2):
48+
run.log_metric("C", i)
49+
for i in range(2):
50+
time.sleep(0.003)
51+
run.log_metric("D", i)
52+
for i in range(2):
53+
time.sleep(0.003)
54+
run.log_metric("E", i)
55+
time.sleep(15)

0 commit comments

Comments
 (0)