|
| 1 | +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""Telemetry module for SageMaker Python SDK to collect usage data and metrics.""" |
| 14 | +from __future__ import absolute_import |
| 15 | +import logging |
| 16 | +import platform |
| 17 | +import sys |
| 18 | +from time import perf_counter |
| 19 | +from typing import List |
| 20 | + |
| 21 | +from sagemaker.utils import resolve_value_from_config |
| 22 | +from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH |
| 23 | +from sagemaker.telemetry.constants import ( |
| 24 | + Feature, |
| 25 | + Status, |
| 26 | + DEFAULT_AWS_REGION, |
| 27 | +) |
| 28 | +from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file |
| 29 | + |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Platform details reported with each telemetry record; placeholder values are
# used when the platform module returns an empty string.
OS_NAME = platform.system() or "UnresolvedOS"
OS_VERSION = platform.release() or "UnresolvedOSVersion"
OS_NAME_VERSION = f"{OS_NAME}/{OS_VERSION}"
# "major.minor.micro" of the running interpreter.
PYTHON_VERSION = ".".join(str(part) for part in sys.version_info[:3])

# Notice logged at INFO level by the telemetry decorator on every wrapped call.
TELEMETRY_OPT_OUT_MESSAGING = (
    "SageMaker Python SDK will collect telemetry to help us better understand our user's needs, "
    "diagnose issues, and deliver additional features.\n"
    "To opt out of telemetry, please disable via TelemetryOptOut parameter in SDK defaults config. "
    "For more information, refer to https://sagemaker.readthedocs.io/en/stable/overview.html"
    "#configuring-and-using-defaults-with-the-sagemaker-python-sdk."
)

# Integer codes sent in the "x-feature" query parameter, keyed by str(Feature member).
FEATURE_TO_CODE = {
    str(Feature.SDK_DEFAULTS): 1,
    str(Feature.LOCAL_MODE): 2,
}

# Integer codes sent in the "x-status" query parameter, keyed by str(Status member).
STATUS_TO_CODE = {
    str(Status.SUCCESS): 1,
    str(Status.FAILURE): 0,
}
| 56 | + |
| 57 | + |
def _telemetry_emitter(feature: str, func_name: str):
    """Decorator factory that emits a telemetry record for each call of a method.

    The decorated callable must be a method whose instance exposes a
    ``sagemaker_session`` attribute; the wrapper reads it to resolve the
    opt-out flag and to detect additional features in use.

    Args:
        feature: Feature being tracked. ``str(feature)`` must be a key of
            ``FEATURE_TO_CODE``.
        func_name: Label placed at the start of the ``x-extra`` payload.

    Returns:
        A decorator that wraps a method, times it, and (unless telemetry is
        opted out) reports success or failure via ``_send_telemetry_request``.
        The wrapped method's return value is returned unchanged and any
        exception it raises propagates to the caller.
    """
    # Local import so this block does not depend on a file-level import change.
    from functools import wraps

    def decorator(func):
        @wraps(func)  # keep the wrapped method's __name__/__doc__ intact
        def wrapper(self, *args, **kwargs):
            logger.info(TELEMETRY_OPT_OUT_MESSAGING)
            studio_app_type = process_studio_metadata_file()

            # Check if telemetry is opted out
            telemetry_opt_out_flag = resolve_value_from_config(
                direct_input=None,
                config_path=TELEMETRY_OPT_OUT_PATH,
                default_value=False,
                sagemaker_session=self.sagemaker_session,
            )
            logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)

            # Construct the feature list to track feature combinations
            feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]
            if self.sagemaker_session:
                if self.sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
                    feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])

                if self.sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
                    feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])

            # Construct the extra info to track platform and environment usage metadata
            extra = (
                f"{func_name}"
                f"&x-sdkVersion={SDK_VERSION}"
                f"&x-env={PYTHON_VERSION}"
                f"&x-sys={OS_NAME_VERSION}"
                f"&x-platform={studio_app_type}"
            )

            # Add endpoint ARN to the extra info if available
            if self.sagemaker_session and self.sagemaker_session.endpoint_arn:
                extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}"

            start_timer = perf_counter()
            try:
                # Call the original function
                response = func(self, *args, **kwargs)
            except Exception as e:  # pylint: disable=W0703
                elapsed = perf_counter() - start_timer
                extra += f"&x-latency={round(elapsed, 2)}"
                if not telemetry_opt_out_flag:
                    _send_telemetry_request(
                        STATUS_TO_CODE[str(Status.FAILURE)],
                        feature_list,
                        self.sagemaker_session,
                        str(e),
                        e.__class__.__name__,
                        extra,
                    )
                # Bare raise keeps the original traceback. There is deliberately
                # no `return` inside a `finally`: that would swallow exceptions
                # not derived from Exception (KeyboardInterrupt, SystemExit).
                raise

            elapsed = perf_counter() - start_timer
            extra += f"&x-latency={round(elapsed, 2)}"
            if not telemetry_opt_out_flag:
                _send_telemetry_request(
                    STATUS_TO_CODE[str(Status.SUCCESS)],
                    feature_list,
                    self.sagemaker_session,
                    None,
                    None,
                    extra,
                )
            return response

        return wrapper

    return decorator
| 137 | + |
| 138 | + |
| 139 | +from sagemaker.session import Session # noqa: E402 pylint: disable=C0413 |
| 140 | + |
| 141 | + |
def _send_telemetry_request(
    status: int,
    feature_list: List[int],
    session: Session,
    failure_reason: str = None,
    failure_type: str = None,
    extra_info: str = None,
) -> None:
    """Build the telemetry URL and issue a best-effort GET request to it.

    Any exception raised while building or sending the request is swallowed
    and only logged at DEBUG level, so telemetry never breaks the caller.

    Args:
        status: Status code placed in the ``x-status`` query parameter.
        feature_list: Feature codes; sent as a comma-separated string.
        session: Session used to look up the account ID and region.
        failure_reason: Optional failure message forwarded to ``_construct_url``.
        failure_type: Optional failure type forwarded to ``_construct_url``.
        extra_info: Optional extra payload forwarded to ``_construct_url``.
    """
    try:
        account_id = _get_accountId(session)
        region = _get_region_or_default(session)
        url = _construct_url(
            account_id,
            region,
            str(status),
            # Comma-joined codes without brackets or quotes, to cut down on length
            ",".join(map(str, feature_list)),
            failure_reason,
            failure_type,
            extra_info,
        )
        # Send the telemetry request
        logger.debug("Sending telemetry request to [%s]", url)
        # _requests_helper returns None when the request itself failed, so only
        # report success when a response object actually came back.
        if _requests_helper(url, 2) is None:
            logger.debug("SageMaker Python SDK telemetry not emitted!!")
        else:
            logger.debug("SageMaker Python SDK telemetry successfully emitted!")
    except Exception:  # pylint: disable=W0703
        logger.debug("SageMaker Python SDK telemetry not emitted!!")
| 171 | + |
| 172 | + |
def _construct_url(
    accountId: str,
    region: str,
    status: str,
    feature: str,
    failure_reason: str,
    failure_type: str,
    extra_info: str,
) -> str:
    """Construct the URL for the telemetry request.

    Args:
        accountId: Value for the ``x-accountId`` query parameter.
        region: Region used in the bucket host name.
        status: Value for the ``x-status`` query parameter.
        feature: Value for the ``x-feature`` query parameter.
        failure_reason: Failure message; when falsy, neither
            ``x-failureReason`` nor ``x-failureType`` is added.
        failure_type: Failure type, added only together with ``failure_reason``.
        extra_info: Extra payload appended after ``x-extra=`` when truthy.

    Returns:
        The complete request URL.
    """
    # Local import so this block does not depend on a file-level import change.
    from urllib.parse import quote

    base_url = (
        f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?"
        f"x-accountId={accountId}"
        f"&x-status={status}"
        f"&x-feature={feature}"
    )
    logger.debug("Failure reason: %s", failure_reason)
    if failure_reason:
        # Exception messages are arbitrary text. Percent-encode them so that
        # characters such as '&', '#' or '=' cannot split or truncate the
        # query string.
        encoded_reason = quote(str(failure_reason), safe="")
        encoded_type = quote(str(failure_type), safe="")
        base_url += f"&x-failureReason={encoded_reason}"
        base_url += f"&x-failureType={encoded_type}"
    if extra_info:
        # Appended verbatim: the emitter in this module passes a value that
        # already contains '&x-...=' pairs meant to be part of the query string.
        base_url += f"&x-extra={extra_info}"
    return base_url
| 197 | + |
| 198 | + |
| 199 | +import requests # noqa: E402 pylint: disable=C0413,C0411 |
| 200 | + |
| 201 | + |
def _requests_helper(url, timeout):
    """Make a GET request to the given URL.

    Args:
        url: URL to request.
        timeout: Timeout forwarded to ``requests.get`` as its ``timeout``
            keyword argument.

    Returns:
        The response returned by ``requests.get``, or ``None`` if a
        ``requests.exceptions.RequestException`` was raised (it is logged).
    """

    response = None
    try:
        # Must be passed by keyword: the second positional parameter of
        # requests.get is `params`, so a positional timeout would be sent as a
        # query parameter and no timeout would be applied.
        response = requests.get(url, timeout=timeout)
    except requests.exceptions.RequestException as e:
        logger.exception("Request exception: %s", str(e))
    return response
| 211 | + |
| 212 | + |
| 213 | +def _get_accountId(session): |
| 214 | + """Return the account ID from the boto session""" |
| 215 | + |
| 216 | + try: |
| 217 | + sts = session.boto_session.client("sts") |
| 218 | + return sts.get_caller_identity()["Account"] |
| 219 | + except Exception: # pylint: disable=W0703 |
| 220 | + return None |
| 221 | + |
| 222 | + |
| 223 | +def _get_region_or_default(session): |
| 224 | + """Return the region name from the boto session or default to us-west-2""" |
| 225 | + |
| 226 | + try: |
| 227 | + return session.boto_session.region_name |
| 228 | + except Exception: # pylint: disable=W0703 |
| 229 | + return DEFAULT_AWS_REGION |
0 commit comments