diff --git a/README.rst b/README.rst index e59b2da9c5..bd77e5884a 100644 --- a/README.rst +++ b/README.rst @@ -95,6 +95,13 @@ SageMaker Python SDK is tested on: - Python 3.10 - Python 3.11 +Telemetry +~~~~~~~~~~~~~~~ + +The ``sagemaker`` library has telemetry enabled to help us better understand user needs, diagnose issues, and deliver new features. This telemetry tracks the usage of various SageMaker functions. + +If you prefer to opt out of telemetry, you can easily do so by setting the ``TelemetryOptOut`` parameter to ``true`` in the SDK defaults configuration. For detailed instructions, please visit `Configuring and using defaults with the SageMaker Python SDK `__. + AWS Permissions ~~~~~~~~~~~~~~~ diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 36a848aa52..89a2df2135 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -42,6 +42,8 @@ _LocalPipeline, ) from sagemaker.session import Session +from sagemaker.telemetry.telemetry_logging import _telemetry_emitter +from sagemaker.telemetry.constants import Feature from sagemaker.utils import ( get_config_value, _module_import_error, @@ -83,6 +85,7 @@ def __init__(self, sagemaker_session=None): """ self.sagemaker_session = sagemaker_session or LocalSession() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_processing_job") def create_processing_job( self, ProcessingJobName, @@ -165,6 +168,7 @@ def describe_processing_job(self, ProcessingJobName): raise ClientError(error_response, "describe_processing_job") return LocalSagemakerClient._processing_jobs[ProcessingJobName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_training_job") def create_training_job( self, TrainingJobName, @@ -235,6 +239,7 @@ def describe_training_job(self, TrainingJobName): raise ClientError(error_response, "describe_training_job") return LocalSagemakerClient._training_jobs[TrainingJobName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_transform_job") def create_transform_job( self, TransformJobName, @@ -280,6 +285,7 @@ def describe_transform_job(self, TransformJobName): raise ClientError(error_response, "describe_transform_job") return LocalSagemakerClient._transform_jobs[TransformJobName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_model") def create_model( self, ModelName, PrimaryContainer, *args, **kwargs ): # pylint: disable=unused-argument @@ -329,6 +335,7 @@ def describe_endpoint_config(self, EndpointConfigName): raise ClientError(error_response, "describe_endpoint_config") return LocalSagemakerClient._endpoint_configs[EndpointConfigName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint_config") def create_endpoint_config(self, EndpointConfigName, ProductionVariants, Tags=None): """Create the endpoint configuration. @@ -360,6 +367,7 @@ def describe_endpoint(self, EndpointName): raise ClientError(error_response, "describe_endpoint") return LocalSagemakerClient._endpoints[EndpointName].describe() + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_endpoint") def create_endpoint(self, EndpointName, EndpointConfigName, Tags=None): """Create the endpoint. @@ -428,6 +436,7 @@ def delete_model(self, ModelName): if ModelName in LocalSagemakerClient._models: del LocalSagemakerClient._models[ModelName] + @_telemetry_emitter(Feature.LOCAL_MODE, "local_session.create_pipeline") def create_pipeline( self, pipeline, pipeline_description, **kwargs # pylint: disable=unused-argument ): diff --git a/src/sagemaker/telemetry/__init__.py b/src/sagemaker/telemetry/__init__.py new file mode 100644 index 0000000000..ada3f1f09f --- /dev/null +++ b/src/sagemaker/telemetry/__init__.py @@ -0,0 +1,16 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Placeholder docstring""" +from __future__ import absolute_import + +from .telemetry_logging import _telemetry_emitter # noqa: F401 diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py new file mode 100644 index 0000000000..f891c226d8 --- /dev/null +++ b/src/sagemaker/telemetry/constants.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Constants used in SageMaker Python SDK telemetry.""" + +from __future__ import absolute_import +from enum import Enum + +# Default AWS region used by SageMaker +DEFAULT_AWS_REGION = "us-west-2" + + +class Feature(Enum): + """Enumeration of feature names used in telemetry.""" + + SDK_DEFAULTS = 1 + LOCAL_MODE = 2 + + def __str__(self): # pylint: disable=E0307 + """Return the feature name.""" + return self.name + + +class Status(Enum): + """Enumeration of status values used in telemetry.""" + + SUCCESS = 1 + FAILURE = 0 + + def __str__(self): # pylint: disable=E0307 + """Return the status name.""" + return self.name diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py new file mode 100644 index 0000000000..59ad58d16a --- /dev/null +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -0,0 +1,229 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Telemetry module for SageMaker Python SDK to collect usage data and metrics.""" +from __future__ import absolute_import +import logging +import platform +import sys +from time import perf_counter +from typing import List + +from sagemaker.utils import resolve_value_from_config +from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH +from sagemaker.telemetry.constants import ( + Feature, + Status, + DEFAULT_AWS_REGION, +) +from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file + +logger = logging.getLogger(__name__) + +OS_NAME = platform.system() or "UnresolvedOS" +OS_VERSION = platform.release() or "UnresolvedOSVersion" +OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION) +PYTHON_VERSION = "{}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro +) + +TELEMETRY_OPT_OUT_MESSAGING = ( + "SageMaker Python SDK will collect telemetry to help us better understand our user's needs, " + "diagnose issues, and deliver additional features.\n" + "To opt out of telemetry, please disable via TelemetryOptOut parameter in SDK defaults config. " + "For more information, refer to https://sagemaker.readthedocs.io/en/stable/overview.html" + "#configuring-and-using-defaults-with-the-sagemaker-python-sdk." +) + +FEATURE_TO_CODE = { + str(Feature.SDK_DEFAULTS): 1, + str(Feature.LOCAL_MODE): 2, +} + +STATUS_TO_CODE = { + str(Status.SUCCESS): 1, + str(Status.FAILURE): 0, +} + + +def _telemetry_emitter(feature: str, func_name: str): + """Decorator to emit telemetry logs for SageMaker Python SDK functions""" + + def decorator(func): + def wrapper(self, *args, **kwargs): + logger.info(TELEMETRY_OPT_OUT_MESSAGING) + response = None + caught_ex = None + studio_app_type = process_studio_metadata_file() + + # Check if telemetry is opted out + telemetry_opt_out_flag = resolve_value_from_config( + direct_input=None, + config_path=TELEMETRY_OPT_OUT_PATH, + default_value=False, + sagemaker_session=self.sagemaker_session, + ) + logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag) + + # Construct the feature list to track feature combinations + feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]] + if self.sagemaker_session: + if self.sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS: + feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)]) + + if self.sagemaker_session.local_mode and feature != Feature.LOCAL_MODE: + feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)]) + + # Construct the extra info to track platform and environment usage metadata + extra = ( + f"{func_name}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + f"&x-platform={studio_app_type}" + ) + + # Add endpoint ARN to the extra info if available + if self.sagemaker_session and self.sagemaker_session.endpoint_arn: + extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}" + + start_timer = perf_counter() + try: + # Call the original function + response = func(self, *args, **kwargs) + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + extra += f"&x-latency={round(elapsed, 2)}" + if not telemetry_opt_out_flag: + _send_telemetry_request( + STATUS_TO_CODE[str(Status.SUCCESS)], + feature_list, + self.sagemaker_session, + None, + None, + extra, + ) + except Exception as e: # pylint: disable=W0703 + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + extra += f"&x-latency={round(elapsed, 2)}" + if not telemetry_opt_out_flag: + _send_telemetry_request( + STATUS_TO_CODE[str(Status.FAILURE)], + feature_list, + self.sagemaker_session, + str(e), + e.__class__.__name__, + extra, + ) + caught_ex = e + finally: + if caught_ex: + raise caught_ex + return response # pylint: disable=W0150 + + return wrapper + + return decorator + + +from sagemaker.session import Session # noqa: E402 pylint: disable=C0413 + + +def _send_telemetry_request( + status: int, + feature_list: List[int], + session: Session, + failure_reason: str = None, + failure_type: str = None, + extra_info: str = None, +) -> None: + """Make GET request to an empty object in S3 bucket""" + try: + accountId = _get_accountId(session) + region = _get_region_or_default(session) + url = _construct_url( + accountId, + region, + str(status), + str( + ",".join(map(str, feature_list)) + ), # Remove brackets and quotes to cut down on length + failure_reason, + failure_type, + extra_info, + ) + # Send the telemetry request + logger.debug("Sending telemetry request to [%s]", url) + _requests_helper(url, 2) + logger.debug("SageMaker Python SDK telemetry successfully emitted!") + except Exception: # pylint: disable=W0703 + logger.debug("SageMaker Python SDK telemetry not emitted!!") + + +def _construct_url( + accountId: str, + region: str, + status: str, + feature: str, + failure_reason: str, + failure_type: str, + extra_info: str, +) -> str: + """Construct the URL for the telemetry request""" + + base_url = ( + f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" + f"x-accountId={accountId}" + f"&x-status={status}" + f"&x-feature={feature}" + ) + logger.debug("Failure reason: %s", failure_reason) + if failure_reason: + base_url += f"&x-failureReason={failure_reason}" + base_url += f"&x-failureType={failure_type}" + if extra_info: + base_url += f"&x-extra={extra_info}" + return base_url + + +import requests # noqa: E402 pylint: disable=C0413,C0411 + + +def _requests_helper(url, timeout): + """Make a GET request to the given URL""" + + response = None + try: + response = requests.get(url, timeout) + except requests.exceptions.RequestException as e: + logger.exception("Request exception: %s", str(e)) + return response + + +def _get_accountId(session): + """Return the account ID from the boto session""" + + try: + sts = session.boto_session.client("sts") + return sts.get_caller_identity()["Account"] + except Exception: # pylint: disable=W0703 + return None + + +def _get_region_or_default(session): + """Return the region name from the boto session or default to us-west-2""" + + try: + return session.boto_session.region_name + except Exception: # pylint: disable=W0703 + return DEFAULT_AWS_REGION diff --git a/tests/integ/sagemaker/conftest.py b/tests/integ/sagemaker/conftest.py index 043b0c703e..46539e6de3 100644 --- a/tests/integ/sagemaker/conftest.py +++ b/tests/integ/sagemaker/conftest.py @@ -102,7 +102,9 @@ "channels:\n" " - defaults\n" "dependencies:\n" - " - scipy=1.10.1\n" + " - requests=2.32.3\n" + " - charset-normalizer=3.3.2\n" + " - scipy=1.13.1\n" " - pip:\n" " - /sagemaker-{sagemaker_version}.tar.gz\n" "prefix: /opt/conda/bin/conda\n" diff --git a/tests/unit/sagemaker/local/test_local_session.py b/tests/unit/sagemaker/local/test_local_session.py index ceae674704..ce8fd19b5c 100644 --- a/tests/unit/sagemaker/local/test_local_session.py +++ b/tests/unit/sagemaker/local/test_local_session.py @@ -47,7 +47,8 @@ @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -142,7 +143,8 @@ def test_create_processing_job(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_not_fully_replicated(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_not_fully_replicated(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -197,7 +199,8 @@ def test_create_processing_job_not_fully_replicated(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_upload_mode(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_upload_mode(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -252,7 +255,8 @@ def test_create_processing_job_invalid_upload_mode(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_processing_input(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_processing_input(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -302,7 +306,8 @@ def test_create_processing_job_invalid_processing_input(process, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.process") @patch("sagemaker.local.local_session.LocalSession") -def test_create_processing_job_invalid_processing_output(process, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_processing_job_invalid_processing_output(process, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -360,7 +365,8 @@ def test_describe_invalid_processing_job(*args): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -427,7 +433,8 @@ def test_describe_invalid_training_job(*args): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job_invalid_data_source(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job_invalid_data_source(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -466,7 +473,8 @@ def test_create_training_job_invalid_data_source(train, LocalSession): @patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model") @patch("sagemaker.local.local_session.LocalSession") -def test_create_training_job_not_fully_replicated(train, LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_training_job_not_fully_replicated(train, LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() instance_count = 2 @@ -503,7 +511,8 @@ def test_create_training_job_not_fully_replicated(train, LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_create_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) @@ -512,7 +521,8 @@ def test_create_model(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_delete_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_delete_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_model(MODEL_NAME, PRIMARY_CONTAINER) @@ -523,7 +533,8 @@ def test_delete_model(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_describe_model(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_describe_model(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() with pytest.raises(ClientError): @@ -536,9 +547,10 @@ def test_describe_model(LocalSession): assert response["PrimaryContainer"]["ModelDataUrl"] == "/some/model/path" +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.local_session._LocalTransformJob") @patch("sagemaker.local.local_session.LocalSession") -def test_create_transform_job(LocalSession, _LocalTransformJob): +def test_create_transform_job(LocalSession, _LocalTransformJob, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_transform_job("transform-job", "some-model", None, None, None) @@ -572,7 +584,8 @@ def test_logs_for_processing_job(process, LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_describe_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_describe_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() # No Endpoint Config Created @@ -588,7 +601,8 @@ def test_describe_endpoint_config(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_create_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_create_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) @@ -598,7 +612,8 @@ def test_create_endpoint_config(LocalSession): @patch("sagemaker.local.local_session.LocalSession") -def test_delete_endpoint_config(LocalSession): +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) +def test_delete_endpoint_config(LocalSession, mock_telemetry): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() local_sagemaker_client.create_endpoint_config(ENDPOINT_CONFIG_NAME, PRODUCTION_VARIANTS) @@ -613,12 +628,15 @@ def test_delete_endpoint_config(LocalSession): ) +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.image._SageMakerContainer.serve") @patch("sagemaker.local.local_session.LocalSession") @patch("urllib3.PoolManager.request") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model") -def test_describe_endpoint(describe_model, describe_endpoint_config, request, *args): +def test_describe_endpoint( + describe_model, describe_endpoint_config, request, mock_telemetry, *args +): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() request.return_value = OK_RESPONSE @@ -658,12 +676,13 @@ def test_describe_endpoint(describe_model, describe_endpoint_config, request, *a assert response["EndpointName"] == "test-endpoint" +@patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config", side_effect=None) @patch("sagemaker.local.image._SageMakerContainer.serve") @patch("sagemaker.local.local_session.LocalSession") @patch("urllib3.PoolManager.request") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_endpoint_config") @patch("sagemaker.local.local_session.LocalSagemakerClient.describe_model") -def test_create_endpoint(describe_model, describe_endpoint_config, request, *args): +def test_create_endpoint(describe_model, describe_endpoint_config, request, mock_telemetry, *args): local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient() request.return_value = OK_RESPONSE diff --git a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py new file mode 100644 index 0000000000..af782c8b51 --- /dev/null +++ b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py @@ -0,0 +1,284 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import unittest +import pytest +import requests +from unittest.mock import Mock, patch, MagicMock +import sagemaker +from sagemaker.telemetry.constants import Feature +from sagemaker.telemetry.telemetry_logging import ( + _send_telemetry_request, + _telemetry_emitter, + _construct_url, + _get_accountId, + _requests_helper, + _get_region_or_default, + OS_NAME_VERSION, + PYTHON_VERSION, +) +from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file +from sagemaker.serve.utils.exceptions import ModelBuilderException, LocalModelOutOfMemoryException + +MOCK_SESSION = Mock() +MOCK_EXCEPTION = LocalModelOutOfMemoryException("mock raise ex") +MOCK_FEATURE = Feature.SDK_DEFAULTS +MOCK_FUNC_NAME = "Mock.local_session.create_model" +MOCK_ENDPOINT_ARN = "arn:aws:sagemaker:us-west-2:123456789012:endpoint/test" + + +class LocalSagemakerClientMock: + def __init__(self): + self.sagemaker_session = MOCK_SESSION + + @_telemetry_emitter(MOCK_FEATURE, MOCK_FUNC_NAME) + def mock_create_model(self, mock_exception_func=None): + if mock_exception_func: + mock_exception_func() + + +class TestTelemetryLogging(unittest.TestCase): + @patch("sagemaker.telemetry.telemetry_logging._requests_helper") + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + def test_log_sucessfully(self, mock_get_accountId, mock_request_helper): + """Test to check if the telemetry logging is successful""" + MOCK_SESSION.boto_session.region_name = "us-west-2" + mock_get_accountId.return_value = "testAccountId" + _send_telemetry_request("someStatus", "1", MOCK_SESSION) + mock_request_helper.assert_called_with( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/" + "telemetry?x-accountId=testAccountId&x-status=someStatus&x-feature=1", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + def test_log_handle_exception(self, mock_get_accountId): + """Test to check if the exception is handled while logging telemetry""" + mock_get_accountId.side_effect = Exception("Internal error") + _send_telemetry_request("someStatus", "1", MOCK_SESSION) + self.assertRaises(Exception) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_success(self, mock_get_region, mock_get_accountId): + """Test to check the _send_telemetry_request function with success status""" + mock_get_accountId.return_value = "testAccountId" + mock_get_region.return_value = "us-west-2" + + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + mock_requests_helper.return_value = None + _send_telemetry_request(1, [1, 2], MagicMock(), None, None, "extra_info") + mock_requests_helper.assert_called_with( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/" + "telemetry?x-accountId=testAccountId&x-status=1&x-feature=1,2&x-extra=extra_info", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_failure(self, mock_get_region, mock_get_accountId): + """Test to check the _send_telemetry_request function with failure status""" + mock_get_accountId.return_value = "testAccountId" + mock_get_region.return_value = "us-west-2" + + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + mock_requests_helper.return_value = None + _send_telemetry_request( + 0, [1, 2], MagicMock(), "failure_reason", "failure_type", "extra_info" + ) + mock_requests_helper.assert_called_with( + "https://sm-pysdk-t-us-west-2.s3.us-west-2.amazonaws.com/" + "telemetry?x-accountId=testAccountId&x-status=0&x-feature=1,2" + "&x-failureReason=failure_reason&x-failureType=failure_type&x-extra=extra_info", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_decorator_no_call_when_disabled( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test to check if the _telemetry_emitter decorator is not called when telemetry is disabled""" + mock_resolve_config.return_value = True + + assert not mock_send_telemetry_request.called + + @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_decorator_success( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test to verify the _telemetry_emitter decorator with success status""" + mock_resolve_config.return_value = False + mock_local_client = LocalSagemakerClientMock() + mock_local_client.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN + mock_local_client.mock_create_model() + app_type = process_studio_metadata_file() + + args = mock_send_telemetry_request.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_FUNC_NAME}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + f"&x-platform={app_type}" + f"&x-endpointArn={MOCK_ENDPOINT_ARN}" + f"&x-latency={latency}" + ) + + mock_send_telemetry_request.assert_called_once_with( + 1, [1, 2], MOCK_SESSION, None, None, expected_extra_str + ) + + @patch("sagemaker.telemetry.telemetry_logging._send_telemetry_request") + @patch("sagemaker.telemetry.telemetry_logging.resolve_value_from_config") + def test_telemetry_emitter_decorator_handle_exception_success( + self, mock_resolve_config, mock_send_telemetry_request + ): + """Test to verify the _telemetry_emitter decorator when function emits exception""" + mock_resolve_config.return_value = False + mock_local_client = LocalSagemakerClientMock() + mock_local_client.sagemaker_session.endpoint_arn = MOCK_ENDPOINT_ARN + app_type = process_studio_metadata_file() + + mock_exception = Mock() + mock_exception_obj = MOCK_EXCEPTION + mock_exception.side_effect = mock_exception_obj + + with self.assertRaises(ModelBuilderException) as _: + mock_local_client.mock_create_model(mock_exception) + + args = mock_send_telemetry_request.call_args.args + latency = str(args[5]).split("latency=")[1] + expected_extra_str = ( + f"{MOCK_FUNC_NAME}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + f"&x-platform={app_type}" + f"&x-endpointArn={MOCK_ENDPOINT_ARN}" + f"&x-latency={latency}" + ) + + mock_send_telemetry_request.assert_called_once_with( + 0, + [1, 2], + MOCK_SESSION, + str(mock_exception_obj), + mock_exception_obj.__class__.__name__, + expected_extra_str, + ) + + def test_construct_url_with_failure_reason_and_extra_info(self): + """Test to verify the _construct_url function with failure reason and extra info""" + mock_accountId = "testAccountId" + mock_status = 0 + mock_feature = "1,2" + mock_failure_reason = str(MOCK_EXCEPTION) + mock_failure_type = MOCK_EXCEPTION.__class__.__name__ + mock_extra_info = "mock_extra_info" + mock_region = "us-west-2" + + resulted_url = _construct_url( + accountId=mock_accountId, + region=mock_region, + status=mock_status, + feature=mock_feature, + failure_reason=mock_failure_reason, + failure_type=mock_failure_type, + extra_info=mock_extra_info, + ) + + expected_base_url = ( + f"https://sm-pysdk-t-{mock_region}.s3.{mock_region}.amazonaws.com/telemetry?" + f"x-accountId={mock_accountId}" + f"&x-status={mock_status}" + f"&x-feature={mock_feature}" + f"&x-failureReason={mock_failure_reason}" + f"&x-failureType={mock_failure_type}" + f"&x-extra={mock_extra_info}" + ) + self.assertEqual(resulted_url, expected_base_url) + + @patch("sagemaker.telemetry.telemetry_logging.requests.get") + def test_requests_helper_success(self, mock_requests_get): + """Test to verify the _requests_helper function with success status""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_requests_get.return_value = mock_response + url = "https://example.com" + timeout = 10 + + response = _requests_helper(url, timeout) + + mock_requests_get.assert_called_once_with(url, timeout) + self.assertEqual(response, mock_response) + + @patch("sagemaker.telemetry.telemetry_logging.requests.get") + def test_requests_helper_exception(self, mock_requests_get): + """Test to verify the _requests_helper function with exception""" + mock_requests_get.side_effect = requests.exceptions.RequestException("Error making request") + url = "https://example.com" + timeout = 10 + + response = _requests_helper(url, timeout) + + mock_requests_get.assert_called_once_with(url, timeout) + self.assertIsNone(response) + + def test_get_accountId_success(self): + """Test to verify the _get_accountId function with success status""" + boto_mock = MagicMock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = {"Account": "testAccountId"} + session = sagemaker.Session(boto_session=boto_mock) + account_id = _get_accountId(session) + + self.assertEqual(account_id, "testAccountId") + + def test_get_accountId_exception(self): + """Test to verify the _get_accountId function with exception""" + sts_client_mock = MagicMock() + sts_client_mock.side_effect = Exception("Error creating STS client") + boto_mock = MagicMock(name="boto_session") + boto_mock.client("sts").get_caller_identity.return_value = sts_client_mock + session = sagemaker.Session(boto_session=boto_mock) + + with pytest.raises(Exception) as exception: + account_id = _get_accountId(session) + assert account_id is None + assert "Error creating STS client" in str(exception) + + def test_get_region_or_default_success(self): + """Test to verify the _get_region_or_default function with success status""" + mock_session = MagicMock() + mock_session.boto_session = MagicMock(region_name="us-east-1") + + region = _get_region_or_default(mock_session) + + assert region == "us-east-1" + + def test_get_region_or_default_exception(self): + """Test to verify the _get_region_or_default function with exception""" + mock_session = MagicMock() + mock_session.boto_session = MagicMock() + mock_session.boto_session.region_name.side_effect = Exception("Error creating boto session") + + with pytest.raises(Exception) as exception: + region = _get_region_or_default(mock_session) + assert region == "us-west-2" + assert "Error creating boto session" in str(exception)