Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2afc516

Browse files
committedJun 11, 2024·
Implement custom telemetry logging in SDK
1 parent ed43b07 commit 2afc516

File tree

6 files changed

+606
-17
lines changed

6 files changed

+606
-17
lines changed
 

‎README.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,13 @@ SageMaker Python SDK is tested on:
9595
- Python 3.10
9696
- Python 3.11
9797

98+
Telemetry
99+
~~~~~~~~~~~~~~~
100+
101+
The ``sagemaker`` library has telemetry enabled to help us better understand user needs, diagnose issues, and deliver new features. This telemetry tracks the usage of various SageMaker functions.
102+
103+
If you prefer to opt out of telemetry, you can easily do so by setting the ``TelemetryOptOut`` parameter to ``true`` in the SDK defaults configuration. For detailed instructions, please visit `Configuring and using defaults with the SageMaker Python SDK <https://sagemaker.readthedocs.io/en/stable/overview.html#configuring-and-using-defaults-with-the-sagemaker-python-sdk>`__.
104+
98105
AWS Permissions
99106
~~~~~~~~~~~~~~~
100107

‎src/sagemaker/local/local_session.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
_LocalPipeline,
4343
)
4444
from sagemaker.session import Session
45+
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
46+
from sagemaker.telemetry.constants import Feature
4547
from sagemaker.utils import (
4648
get_config_value,
4749
_module_import_error,
@@ -83,6 +85,7 @@ def __init__(self, sagemaker_session=None):
8385
"""
8486
self.sagemaker_session = sagemaker_session or LocalSession()
8587

88+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_processing_job")
8689
def create_processing_job(
8790
self,
8891
ProcessingJobName,
@@ -165,6 +168,7 @@ def describe_processing_job(self, ProcessingJobName):
165168
raise ClientError(error_response, "describe_processing_job")
166169
return LocalSagemakerClient._processing_jobs[ProcessingJobName].describe()
167170

171+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_training_job")
168172
def create_training_job(
169173
self,
170174
TrainingJobName,
@@ -235,6 +239,7 @@ def describe_training_job(self, TrainingJobName):
235239
raise ClientError(error_response, "describe_training_job")
236240
return LocalSagemakerClient._training_jobs[TrainingJobName].describe()
237241

242+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_transform_job")
238243
def create_transform_job(
239244
self,
240245
TransformJobName,
@@ -280,6 +285,7 @@ def describe_transform_job(self, TransformJobName):
280285
raise ClientError(error_response, "describe_transform_job")
281286
return LocalSagemakerClient._transform_jobs[TransformJobName].describe()
282287

288+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_model")
283289
def create_model(
284290
self, ModelName, PrimaryContainer, *args, **kwargs
285291
): # pylint: disable=unused-argument
@@ -329,6 +335,7 @@ def describe_endpoint_config(self, EndpointConfigName):
329335
raise ClientError(error_response, "describe_endpoint_config")
330336
return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe()
331337

338+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint_config")
332339
def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None):
333340
"""Create the endpoint configuration.
334341
@@ -360,6 +367,7 @@ def describe_endpoint(self, EndpointName):
360367
raise ClientError(error_response, "describe_endpoint")
361368
return LocalSagemakerClient._endpoints[EndpointName].describe()
362369

370+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint")
363371
def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None):
364372
"""Create the endpoint.
365373
@@ -428,6 +436,7 @@ def delete_model(self, ModelName):
428436
if ModelName in LocalSagemakerClient._models:
429437
del LocalSagemakerClient._models[ModelName]
430438

439+
@_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_pipeline")
431440
def create_pipeline(
432441
self, pipeline, pipeline_description, **kwargs # pylint: disable=unused-argument
433442
):

‎src/sagemaker/telemetry/constants.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Constants used in SageMaker Python SDK telemetry."""
14+
15+
from __future__ import absolute_import
16+
from enum import Enum
17+
18+
# Default AWS region used by SageMaker
19+
DEFAULT_AWS_REGION = "us-west-2"
20+
21+
22+
class Feature(Enum):
23+
"""Enumeration of feature names used in telemetry."""
24+
25+
SDK_DEFAULTS = 1
26+
LOCAL_MODE = 2
27+
28+
def __str__(self):
29+
"""Return the feature name."""
30+
return self.name
31+
32+
33+
class Status(Enum):
34+
"""Enumeration of status values used in telemetry."""
35+
36+
SUCCESS = 1
37+
FAILURE = 0
38+
39+
def __str__(self):
40+
"""Return the status name."""
41+
return self.name
Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Telemetry module for SageMaker Python SDK to collect usage data and metrics."""
14+
from __future__ import absolute_import
15+
import logging
16+
import platform
17+
import sys
18+
from time import perf_counter
19+
from typing import List
20+
21+
import requests
22+
23+
24+
from sagemaker.utils import resolve_value_from_config
25+
from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH
26+
from sagemaker.telemetry.constants import (
27+
Feature,
28+
Status,
29+
DEFAULT_AWS_REGION,
30+
)
31+
from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file
32+
33+
logger = logging.getLogger(__name__)
34+
35+
OS_NAME = platform.system() or "UnresolvedOS"
36+
OS_VERSION = platform.release() or "UnresolvedOSVersion"
37+
OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION)
38+
PYTHON_VERSION = "{}.{}.{}".format(
39+
sys.version_info.major, sys.version_info.minor, sys.version_info.micro
40+
)
41+
42+
TELEMETRY_OPT_OUT_MESSAGING = (
43+
"SageMaker Python SDK will collect telemetry to help us better understand our user's needs, "
44+
"diagnose issues, and deliver additional features.\n"
45+
"To opt out of telemetry, please disable via TelemetryOptOut parameter in SDK defaults config. "
46+
"For more information, refer to https://sagemaker.readthedocs.io/en/stable/overview.html"
47+
"#configuring-and-using-defaults-with-the-sagemaker-python-sdk."
48+
)
49+
50+
FEATURE_TO_CODE = {
51+
str(Feature.SDK_DEFAULTS): 1,
52+
str(Feature.LOCAL_MODE): 2,
53+
}
54+
55+
STATUS_TO_CODE = {
56+
str(Status.SUCCESS): 1,
57+
str(Status.FAILURE): 0,
58+
}
59+
60+
61+
def _telemetry_emitter(feature: str, func_name: str):
62+
"""Decorator to emit telemetry logs for SageMaker Python SDK functions"""
63+
64+
def decorator(func):
65+
def wrapper(self, *args, **kwargs):
66+
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
67+
response = None
68+
caught_ex = None
69+
studio_app_type = process_studio_metadata_file()
70+
71+
# Check if telemetry is opted out
72+
telemetry_opt_out_flag = resolve_value_from_config(
73+
direct_input=None,
74+
config_path=TELEMETRY_OPT_OUT_PATH,
75+
default_value=False,
76+
sagemaker_session=self.sagemaker_session,
77+
)
78+
logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)
79+
80+
# Construct the feature list to track feature combinations
81+
feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]
82+
if self.sagemaker_session:
83+
if self.sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
84+
feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])
85+
86+
if self.sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
87+
feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])
88+
89+
# Construct the extra info to track platform and environment usage metadata
90+
extra = (
91+
f"{func_name}"
92+
f"&x-sdkVersion={SDK_VERSION}"
93+
f"&x-env={PYTHON_VERSION}"
94+
f"&x-sys={OS_NAME_VERSION}"
95+
f"&x-platform={studio_app_type}"
96+
)
97+
98+
# Add endpoint ARN to the extra info if available
99+
if self.sagemaker_session and self.sagemaker_session.endpoint_arn:
100+
extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}"
101+
102+
start_timer = perf_counter()
103+
try:
104+
# Call the original function
105+
response = func(self, *args, **kwargs)
106+
stop_timer = perf_counter()
107+
elapsed = stop_timer - start_timer
108+
extra += f"&x-latency={round(elapsed, 2)}"
109+
if not telemetry_opt_out_flag:
110+
_send_telemetry_request(
111+
STATUS_TO_CODE[str(Status.SUCCESS)],
112+
feature_list,
113+
self.sagemaker_session,
114+
None,
115+
None,
116+
extra,
117+
)
118+
except Exception as e:
119+
stop_timer = perf_counter()
120+
elapsed = stop_timer - start_timer
121+
extra += f"&x-latency={round(elapsed, 2)}"
122+
if not telemetry_opt_out_flag:
123+
_send_telemetry_request(
124+
STATUS_TO_CODE[str(Status.FAILURE)],
125+
feature_list,
126+
self.sagemaker_session,
127+
str(e),
128+
e.__class__.__name__,
129+
extra,
130+
)
131+
caught_ex = e
132+
finally:
133+
if caught_ex:
134+
raise caught_ex
135+
return response # pylint: disable=W0150
136+
137+
return wrapper
138+
139+
return decorator
140+
141+
142+
from sagemaker.session import Session # noqa: E402
143+
144+
145+
def _send_telemetry_request(
146+
status: int,
147+
feature_list: List[int],
148+
session: Session,
149+
failure_reason: str = None,
150+
failure_type: str = None,
151+
extra_info: str = None,
152+
) -> None:
153+
"""Make GET request to an empty object in S3 bucket"""
154+
try:
155+
accountId = _get_accountId(session)
156+
region = _get_region_or_default(session)
157+
url = _construct_url(
158+
accountId,
159+
region,
160+
str(status),
161+
str(
162+
",".join(map(str, feature_list))
163+
), # Remove brackets and quotes to cut down on length
164+
failure_reason,
165+
failure_type,
166+
extra_info,
167+
)
168+
# Send the telemetry request
169+
logger.debug("Sending telemetry request to [%s]", url)
170+
_requests_helper(url, 2)
171+
logger.debug("SageMaker Python SDK telemetry successfully emitted!")
172+
except Exception: # pylint: disable=W0703
173+
logger.debug("SageMaker Python SDK telemetry not emitted!!")
174+
175+
176+
def _construct_url(
177+
accountId: str,
178+
region: str,
179+
status: str,
180+
feature: str,
181+
failure_reason: str,
182+
failure_type: str,
183+
extra_info: str,
184+
) -> str:
185+
"""Construct the URL for the telemetry request"""
186+
187+
base_url = (
188+
f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?"
189+
f"x-accountId={accountId}"
190+
f"&x-status={status}"
191+
f"&x-feature={feature}"
192+
)
193+
logger.debug("Failure reason: %s", failure_reason)
194+
if failure_reason:
195+
base_url += f"&x-failureReason={failure_reason}"
196+
base_url += f"&x-failureType={failure_type}"
197+
if extra_info:
198+
base_url += f"&x-extra={extra_info}"
199+
return base_url
200+
201+
202+
def _requests_helper(url, timeout):
203+
"""Make a GET request to the given URL"""
204+
205+
response = None
206+
try:
207+
response = requests.get(url, timeout)
208+
except requests.exceptions.RequestException as e:
209+
logger.exception("Request exception: %s", str(e))
210+
return response
211+
212+
213+
def _get_accountId(session):
214+
"""Return the account ID from the boto session"""
215+
216+
try:
217+
sts = session.boto_session.client("sts")
218+
return sts.get_caller_identity()["Account"]
219+
except Exception: # pylint: disable=W0703
220+
return None
221+
222+
223+
def _get_region_or_default(session):
224+
"""Return the region name from the boto session or default to us-west-2"""
225+
226+
try:
227+
return session.boto_session.region_name
228+
except Exception: # pylint: disable=W0703
229+
return DEFAULT_AWS_REGION

‎tests/unit/sagemaker/local/test_local_session.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@
4747

4848
@patch("sagemaker.local.image._SageMakerContainer.process")
4949
@patch("sagemaker.local.local_session.LocalSession")
50-
def test_create_processing_job(process, LocalSession):
50+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
51+
def test_create_processing_job(process, LocalSession, mock_telemetry):
5152
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
5253

5354
instance_count = 2
@@ -142,7 +143,8 @@ def test_create_processing_job(process, LocalSession):
142143

143144
@patch("sagemaker.local.image._SageMakerContainer.process")
144145
@patch("sagemaker.local.local_session.LocalSession")
145-
def test_create_processing_job_not_fully_replicated(process, LocalSession):
146+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
147+
def test_create_processing_job_not_fully_replicated(process, LocalSession, mock_telemetry):
146148
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
147149

148150
instance_count = 2
@@ -197,7 +199,8 @@ def test_create_processing_job_not_fully_replicated(process, LocalSession):
197199

198200
@patch("sagemaker.local.image._SageMakerContainer.process")
199201
@patch("sagemaker.local.local_session.LocalSession")
200-
def test_create_processing_job_invalid_upload_mode(process, LocalSession):
202+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
203+
def test_create_processing_job_invalid_upload_mode(process, LocalSession, mock_telemetry):
201204
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
202205

203206
instance_count = 2
@@ -252,7 +255,8 @@ def test_create_processing_job_invalid_upload_mode(process, LocalSession):
252255

253256
@patch("sagemaker.local.image._SageMakerContainer.process")
254257
@patch("sagemaker.local.local_session.LocalSession")
255-
def test_create_processing_job_invalid_processing_input(process, LocalSession):
258+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
259+
def test_create_processing_job_invalid_processing_input(process, LocalSession, mock_telemetry):
256260
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
257261

258262
instance_count = 2
@@ -302,7 +306,8 @@ def test_create_processing_job_invalid_processing_input(process, LocalSession):
302306

303307
@patch("sagemaker.local.image._SageMakerContainer.process")
304308
@patch("sagemaker.local.local_session.LocalSession")
305-
def test_create_processing_job_invalid_processing_output(process, LocalSession):
309+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
310+
def test_create_processing_job_invalid_processing_output(process, LocalSession, mock_telemetry):
306311
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
307312

308313
instance_count = 2
@@ -360,7 +365,8 @@ def test_describe_invalid_processing_job(*args):
360365

361366
@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
362367
@patch("sagemaker.local.local_session.LocalSession")
363-
def test_create_training_job(train, LocalSession):
368+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
369+
def test_create_training_job(train, LocalSession, mock_telemetry):
364370
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
365371

366372
instance_count = 2
@@ -427,7 +433,8 @@ def test_describe_invalid_training_job(*args):
427433

428434
@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
429435
@patch("sagemaker.local.local_session.LocalSession")
430-
def test_create_training_job_invalid_data_source(train, LocalSession):
436+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
437+
def test_create_training_job_invalid_data_source(train, LocalSession, mock_telemetry):
431438
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
432439

433440
instance_count = 2
@@ -466,7 +473,8 @@ def test_create_training_job_invalid_data_source(train, LocalSession):
466473

467474
@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
468475
@patch("sagemaker.local.local_session.LocalSession")
469-
def test_create_training_job_not_fully_replicated(train, LocalSession):
476+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
477+
def test_create_training_job_not_fully_replicated(train, LocalSession, mock_telemetry):
470478
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
471479

472480
instance_count = 2
@@ -503,7 +511,8 @@ def test_create_training_job_not_fully_replicated(train, LocalSession):
503511

504512

505513
@patch("sagemaker.local.local_session.LocalSession")
506-
def test_create_model(LocalSession):
514+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
515+
def test_create_model(LocalSession, mock_telemetry):
507516
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
508517

509518
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
@@ -512,7 +521,8 @@ def test_create_model(LocalSession):
512521

513522

514523
@patch("sagemaker.local.local_session.LocalSession")
515-
def test_delete_model(LocalSession):
524+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
525+
def test_delete_model(LocalSession, mock_telemetry):
516526
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
517527

518528
local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER)
@@ -523,7 +533,8 @@ def test_delete_model(LocalSession):
523533

524534

525535
@patch("sagemaker.local.local_session.LocalSession")
526-
def test_describe_model(LocalSession):
536+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
537+
def test_describe_model(LocalSession, mock_telemetry):
527538
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
528539

529540
with pytest.raises(ClientError):
@@ -536,9 +547,10 @@ def test_describe_model(LocalSession):
536547
assert response["PrimaryContainer"]["ModelDataUrl"] == "/some/model/path"
537548

538549

550+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
539551
@patch("sagemaker.local.local_session._LocalTransformJob")
540552
@patch("sagemaker.local.local_session.LocalSession")
541-
def test_create_transform_job(LocalSession, _LocalTransformJob):
553+
def test_create_transform_job(LocalSession, _LocalTransformJob, mock_telemetry):
542554
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
543555

544556
local_sagemaker_client.create_transform_job("transform-job", "some-model", None, None, None)
@@ -572,7 +584,8 @@ def test_logs_for_processing_job(process, LocalSession):
572584

573585

574586
@patch("sagemaker.local.local_session.LocalSession")
575-
def test_describe_endpoint_config(LocalSession):
587+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
588+
def test_describe_endpoint_config(LocalSession, mock_telemetry):
576589
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
577590

578591
# No Endpoint Config Created
@@ -588,7 +601,8 @@ def test_describe_endpoint_config(LocalSession):
588601

589602

590603
@patch("sagemaker.local.local_session.LocalSession")
591-
def test_create_endpoint_config(LocalSession):
604+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
605+
def test_create_endpoint_config(LocalSession, mock_telemetry):
592606
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
593607
local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)
594608

@@ -598,7 +612,8 @@ def test_create_endpoint_config(LocalSession):
598612

599613

600614
@patch("sagemaker.local.local_session.LocalSession")
601-
def test_delete_endpoint_config(LocalSession):
615+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
616+
def test_delete_endpoint_config(LocalSession, mock_telemetry):
602617
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
603618

604619
local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS)
@@ -613,12 +628,15 @@ def test_delete_endpoint_config(LocalSession):
613628
)
614629

615630

631+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
616632
@patch("sagemaker.local.image._SageMakerContainer.serve")
617633
@patch("sagemaker.local.local_session.LocalSession")
618634
@patch("urllib3.PoolManager.request")
619635
@patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config")
620636
@patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model")
621-
def test_describe_endpoint(describe_model, describe_endpoint_config, request, *args):
637+
def test_describe_endpoint(
638+
describe_model, describe_endpoint_config, request, mock_telemetry, *args
639+
):
622640
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
623641

624642
request.return_value = OK_RESPONSE
@@ -658,12 +676,13 @@ def test_describe_endpoint(describe_model, describe_endpoint_config, request, *a
658676
assert response["EndpointName"] == "test-endpoint"
659677

660678

679+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None)
661680
@patch("sagemaker.local.image._SageMakerContainer.serve")
662681
@patch("sagemaker.local.local_session.LocalSession")
663682
@patch("urllib3.PoolManager.request")
664683
@patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config")
665684
@patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model")
666-
def test_create_endpoint(describe_model, describe_endpoint_config, request, *args):
685+
def test_create_endpoint(describe_model, describe_endpoint_config, request, mock_telemetry, *args):
667686
local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
668687

669688
request.return_value = OK_RESPONSE
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import unittest
15+
import pytest
16+
import requests
17+
from unittest.mock import Mock, patch, MagicMock
18+
import sagemaker
19+
from sagemaker.telemetry.constants import Feature
20+
from sagemaker.telemetry.telemetry_logging import (
21+
_send_telemetry_request,
22+
_telemetry_emitter,
23+
_construct_url,
24+
_get_accountId,
25+
_requests_helper,
26+
_get_region_or_default,
27+
OS_NAME_VERSION,
28+
PYTHON_VERSION,
29+
)
30+
from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file
31+
from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException
32+
33+
MOCK_SESSION = Mock()
34+
MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex")
35+
MOCK_FEATURE = Feature.SDK_DEFAULTS
36+
MOCK_FUNC_NAME = "Mock.local_session.create_model"
37+
MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test"
38+
39+
40+
class LocalSagemakerClientMock:
41+
def __init__(self):
42+
self.sagemaker_session = MOCK_SESSION
43+
44+
@_telemetry_emitter(MOCK_FEATURE, MOCK_FUNC_NAME)
45+
def mock_create_model(self, mock_exception_func=None):
46+
if mock_exception_func:
47+
mock_exception_func()
48+
49+
50+
class TestTelemetryLogging(unittest.TestCase):
51+
@patch("sagemaker.telemetry.telemetry_logging._requests_helper")
52+
@patch("sagemaker.telemetry.telemetry_logging._get_accountId")
53+
def test_log_sucessfully(self, mock_get_accountId, mock_request_helper):
54+
"""Test to check if the telemetry logging is successful"""
55+
MOCK_SESSION.boto_session.region_name = "us-west-2"
56+
mock_get_accountId.return_value = "testAccountId"
57+
_send_telemetry_request("someStatus", "1", MOCK_SESSION)
58+
mock_request_helper.assert_called_with(
59+
"https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/"
60+
"telemetry?x-accountId=testAccountId&x-status=someStatus&x-feature=1",
61+
2,
62+
)
63+
64+
@patch("sagemaker.telemetry.telemetry_logging._get_accountId")
65+
def test_log_handle_exception(self, mock_get_accountId):
66+
"""Test to check if the exception is handled while logging telemetry"""
67+
mock_get_accountId.side_effect = Exception("Internal error")
68+
_send_telemetry_request("someStatus", "1", MOCK_SESSION)
69+
self.assertRaises(Exception)
70+
71+
@patch("sagemaker.telemetry.telemetry_logging._get_accountId")
72+
@patch("sagemaker.telemetry.telemetry_logging._get_region_or_default")
73+
def test_send_telemetry_request_success(self, mock_get_region, mock_get_accountId):
74+
"""Test to check the _send_telemetry_request function with success status"""
75+
mock_get_accountId.return_value = "testAccountId"
76+
mock_get_region.return_value = "us-west-2"
77+
78+
with patch(
79+
"sagemaker.telemetry.telemetry_logging._requests_helper"
80+
) as mock_requests_helper:
81+
mock_requests_helper.return_value = None
82+
_send_telemetry_request(1, [1, 2], MagicMock(), None, None, "extra_info")
83+
mock_requests_helper.assert_called_with(
84+
"https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/"
85+
"telemetry?x-accountId=testAccountId&x-status=1&x-feature=1,2&x-extra=extra_info",
86+
2,
87+
)
88+
89+
@patch("sagemaker.telemetry.telemetry_logging._get_accountId")
90+
@patch("sagemaker.telemetry.telemetry_logging._get_region_or_default")
91+
def test_send_telemetry_request_failure(self, mock_get_region, mock_get_accountId):
92+
"""Test to check the _send_telemetry_request function with failure status"""
93+
mock_get_accountId.return_value = "testAccountId"
94+
mock_get_region.return_value = "us-west-2"
95+
96+
with patch(
97+
"sagemaker.telemetry.telemetry_logging._requests_helper"
98+
) as mock_requests_helper:
99+
mock_requests_helper.return_value = None
100+
_send_telemetry_request(
101+
0, [1, 2], MagicMock(), "failure_reason", "failure_type", "extra_info"
102+
)
103+
mock_requests_helper.assert_called_with(
104+
"https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/"
105+
"telemetry?x-accountId=testAccountId&x-status=0&x-feature=1,2"
106+
"&x-failureReason=failure_reason&x-failureType=failure_type&x-extra=extra_info",
107+
2,
108+
)
109+
110+
@patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request")
111+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config")
112+
def test_telemetry_emitter_decorator_no_call_when_disabled(
113+
self, mock_resolve_config, mock_send_telemetry_request
114+
):
115+
"""Test to check if the _telemetry_emitter decorator is not called when telemetry is disabled"""
116+
mock_resolve_config.return_value = True
117+
118+
assert not mock_send_telemetry_request.called
119+
120+
@patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request")
121+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config")
122+
def test_telemetry_emitter_decorator_success(
123+
self, mock_resolve_config, mock_send_telemetry_request
124+
):
125+
"""Test to verify the _telemetry_emitter decorator with success status"""
126+
mock_resolve_config.return_value = False
127+
mock_local_client = LocalSagemakerClientMock()
128+
mock_local_client.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN
129+
mock_local_client.mock_create_model()
130+
app_type = process_studio_metadata_file()
131+
132+
args = mock_send_telemetry_request.call_args.args
133+
latency = str(args[5]).split("latency=")[1]
134+
expected_extra_str = (
135+
f"{MOCK_FUNC_NAME}"
136+
f"&x-sdkVersion={SDK_VERSION}"
137+
f"&x-env={PYTHON_VERSION}"
138+
f"&x-sys={OS_NAME_VERSION}"
139+
f"&x-platform={app_type}"
140+
f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
141+
f"&x-latency={latency}"
142+
)
143+
144+
mock_send_telemetry_request.assert_called_once_with(
145+
1, [1, 2], MOCK_SESSION, None, None, expected_extra_str
146+
)
147+
148+
@patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request")
149+
@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config")
150+
def test_telemetry_emitter_decorator_handle_exception_success(
151+
self, mock_resolve_config, mock_send_telemetry_request
152+
):
153+
"""Test to verify the _telemetry_emitter decorator when function emits exception"""
154+
mock_resolve_config.return_value = False
155+
mock_local_client = LocalSagemakerClientMock()
156+
mock_local_client.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN
157+
app_type = process_studio_metadata_file()
158+
159+
mock_exception = Mock()
160+
mock_exception_obj = MOCK_EXCEPTION
161+
mock_exception.side_effect = mock_exception_obj
162+
163+
with self.assertRaises(ModelBuilderException) as _:
164+
mock_local_client.mock_create_model(mock_exception)
165+
166+
args = mock_send_telemetry_request.call_args.args
167+
latency = str(args[5]).split("latency=")[1]
168+
expected_extra_str = (
169+
f"{MOCK_FUNC_NAME}"
170+
f"&x-sdkVersion={SDK_VERSION}"
171+
f"&x-env={PYTHON_VERSION}"
172+
f"&x-sys={OS_NAME_VERSION}"
173+
f"&x-platform={app_type}"
174+
f"&x-endpointArn={MOCK_ENDPOINT_ARN}"
175+
f"&x-latency={latency}"
176+
)
177+
178+
mock_send_telemetry_request.assert_called_once_with(
179+
0,
180+
[1, 2],
181+
MOCK_SESSION,
182+
str(mock_exception_obj),
183+
mock_exception_obj.__class__.__name__,
184+
expected_extra_str,
185+
)
186+
187+
def test_construct_url_with_failure_reason_and_extra_info(self):
188+
"""Test to verify the _construct_url function with failure reason and extra info"""
189+
mock_accountId = "testAccountId"
190+
mock_status = 0
191+
mock_feature = "1,2"
192+
mock_failure_reason = str(MOCK_EXCEPTION)
193+
mock_failure_type = MOCK_EXCEPTION.__class__.__name__
194+
mock_extra_info = "mock_extra_info"
195+
mock_region = "us-west-2"
196+
197+
resulted_url = _construct_url(
198+
accountId=mock_accountId,
199+
region=mock_region,
200+
status=mock_status,
201+
feature=mock_feature,
202+
failure_reason=mock_failure_reason,
203+
failure_type=mock_failure_type,
204+
extra_info=mock_extra_info,
205+
)
206+
207+
expected_base_url = (
208+
f"https://sm-pysdk-t-{mock_region}.s3.{mock_region}.amazonaws.com/telemetry?"
209+
f"x-accountId={mock_accountId}"
210+
f"&x-status={mock_status}"
211+
f"&x-feature={mock_feature}"
212+
f"&x-failureReason={mock_failure_reason}"
213+
f"&x-failureType={mock_failure_type}"
214+
f"&x-extra={mock_extra_info}"
215+
)
216+
self.assertEqual(resulted_url, expected_base_url)
217+
218+
@patch("sagemaker.telemetry.telemetry_logging.requests.get")
219+
def test_requests_helper_success(self, mock_requests_get):
220+
"""Test to verify the _requests_helper function with success status"""
221+
mock_response = MagicMock()
222+
mock_response.status_code = 200
223+
mock_requests_get.return_value = mock_response
224+
url = "https://example.com"
225+
timeout = 10
226+
227+
response = _requests_helper(url, timeout)
228+
229+
mock_requests_get.assert_called_once_with(url, timeout)
230+
self.assertEqual(response, mock_response)
231+
232+
@patch("sagemaker.telemetry.telemetry_logging.requests.get")
233+
def test_requests_helper_exception(self, mock_requests_get):
234+
"""Test to verify the _requests_helper function with exception"""
235+
mock_requests_get.side_effect = requests.exceptions.RequestException("Error making request")
236+
url = "https://example.com"
237+
timeout = 10
238+
239+
response = _requests_helper(url, timeout)
240+
241+
mock_requests_get.assert_called_once_with(url, timeout)
242+
self.assertIsNone(response)
243+
244+
def test_get_accountId_success(self):
245+
"""Test to verify the _get_accountId function with success status"""
246+
boto_mock = MagicMock(name="boto_session")
247+
boto_mock.client("sts").get_caller_identity.return_value = {"Account": "testAccountId"}
248+
session = sagemaker.Session(boto_session=boto_mock)
249+
account_id = _get_accountId(session)
250+
251+
self.assertEqual(account_id, "testAccountId")
252+
253+
def test_get_accountId_exception(self):
254+
"""Test to verify the _get_accountId function with exception"""
255+
sts_client_mock = MagicMock()
256+
sts_client_mock.side_effect = Exception("Error creating STS client")
257+
boto_mock = MagicMock(name="boto_session")
258+
boto_mock.client("sts").get_caller_identity.return_value = sts_client_mock
259+
session = sagemaker.Session(boto_session=boto_mock)
260+
261+
with pytest.raises(Exception) as exception:
262+
account_id = _get_accountId(session)
263+
assert account_id is None
264+
assert "Error creating STS client" in str(exception)
265+
266+
def test_get_region_or_default_success(self):
267+
"""Test to verify the _get_region_or_default function with success status"""
268+
mock_session = MagicMock()
269+
mock_session.boto_session = MagicMock(region_name="us-east-1")
270+
271+
region = _get_region_or_default(mock_session)
272+
273+
assert region == "us-east-1"
274+
275+
def test_get_region_or_default_exception(self):
276+
"""Test to verify the _get_region_or_default function with exception"""
277+
mock_session = MagicMock()
278+
mock_session.boto_session = MagicMock()
279+
mock_session.boto_session.region_name.side_effect = Exception("Error creating boto session")
280+
281+
with pytest.raises(Exception) as exception:
282+
region = _get_region_or_default(mock_session)
283+
assert region == "us-west-2"
284+
assert "Error creating boto session" in str(exception)

0 commit comments

Comments
 (0)
Please sign in to comment.