From 0453c77bbc8b87218fa808c24e8fc90fffe8d019 Mon Sep 17 00:00:00 2001 From: knikure Date: Wed, 12 Jun 2024 21:09:06 +0000 Subject: [PATCH 1/3] change: Enhance telemetry logging module and feature coverage --- src/sagemaker/remote_function/client.py | 3 + src/sagemaker/telemetry/constants.py | 1 + src/sagemaker/telemetry/telemetry_logging.py | 166 ++++++++++-------- .../sagemaker/remote_function/test_client.py | 7 +- 4 files changed, 98 insertions(+), 79 deletions(-) diff --git a/src/sagemaker/remote_function/client.py b/src/sagemaker/remote_function/client.py index 0dc69d8647..53a116e4ef 100644 --- a/src/sagemaker/remote_function/client.py +++ b/src/sagemaker/remote_function/client.py @@ -40,6 +40,8 @@ from sagemaker.utils import name_from_base, base_from_name from sagemaker.remote_function.spark_config import SparkConfig from sagemaker.remote_function.custom_file_filter import CustomFileFilter +from sagemaker.telemetry.telemetry_logging import _telemetry_emitter +from sagemaker.telemetry.constants import Feature _API_CALL_LIMIT = { "SubmittingIntervalInSecs": 1, @@ -57,6 +59,7 @@ logger = logging_config.get_logger() +@_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote") def remote( _func=None, *, diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py index f891c226d8..332d706351 100644 --- a/src/sagemaker/telemetry/constants.py +++ b/src/sagemaker/telemetry/constants.py @@ -24,6 +24,7 @@ class Feature(Enum): SDK_DEFAULTS = 1 LOCAL_MODE = 2 + REMOTE_FUNCTION = 3 def __str__(self): # pylint: disable=E0307 """Return the feature name.""" diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py index 59ad58d16a..f56eaec0e0 100644 --- a/src/sagemaker/telemetry/telemetry_logging.py +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -17,7 +17,10 @@ import sys from time import perf_counter from typing import List +import functools +import requests +from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH from sagemaker.telemetry.constants import ( @@ -47,6 +50,7 @@ FEATURE_TO_CODE = { str(Feature.SDK_DEFAULTS): 1, str(Feature.LOCAL_MODE): 2, + str(Feature.REMOTE_FUNCTION): 3, } STATUS_TO_CODE = { @@ -59,86 +63,101 @@ def _telemetry_emitter(feature: str, func_name: str): """Decorator to emit telemetry logs for SageMaker Python SDK functions""" def decorator(func): - def wrapper(self, *args, **kwargs): - logger.info(TELEMETRY_OPT_OUT_MESSAGING) - response = None - caught_ex = None - studio_app_type = process_studio_metadata_file() - - # Check if telemetry is opted out - telemetry_opt_out_flag = resolve_value_from_config( - direct_input=None, - config_path=TELEMETRY_OPT_OUT_PATH, - default_value=False, - sagemaker_session=self.sagemaker_session, - ) - logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag) - - # Construct the feature list to track feature combinations - feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]] - if self.sagemaker_session: - if self.sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS: + @functools.wraps(func) + def wrapper(*args, **kwargs): + sagemaker_session = None + if len(args) > 0 and hasattr(args[0], "sagemaker_session"): + # Get the sagemaker_session from the instance method args + sagemaker_session = args[0].sagemaker_session + elif feature == Feature.REMOTE_FUNCTION: + # Get the sagemaker_session from the function keyword arguments for remote function + sagemaker_session = kwargs.get("sagemaker_session", Session()) + + if sagemaker_session: + logger.debug("sagemaker_session found, preparing to emit telemetry...") + logger.info(TELEMETRY_OPT_OUT_MESSAGING) + response = None + caught_ex = None + studio_app_type = process_studio_metadata_file() + + # Check if telemetry is opted out + telemetry_opt_out_flag = resolve_value_from_config( + direct_input=None, + config_path=TELEMETRY_OPT_OUT_PATH, + default_value=False, + sagemaker_session=sagemaker_session, + ) + logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag) + + # Construct the feature list to track feature combinations + feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]] + + if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS: feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)]) - if self.sagemaker_session.local_mode and feature != Feature.LOCAL_MODE: + if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE: feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)]) - # Construct the extra info to track platform and environment usage metadata - extra = ( - f"{func_name}" - f"&x-sdkVersion={SDK_VERSION}" - f"&x-env={PYTHON_VERSION}" - f"&x-sys={OS_NAME_VERSION}" - f"&x-platform={studio_app_type}" - ) - - # Add endpoint ARN to the extra info if available - if self.sagemaker_session and self.sagemaker_session.endpoint_arn: - extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}" - - start_timer = perf_counter() - try: - # Call the original function - response = func(self, *args, **kwargs) - stop_timer = perf_counter() - elapsed = stop_timer - start_timer - extra += f"&x-latency={round(elapsed, 2)}" - if not telemetry_opt_out_flag: - _send_telemetry_request( - STATUS_TO_CODE[str(Status.SUCCESS)], - feature_list, - self.sagemaker_session, - None, - None, - extra, - ) - except Exception as e: # pylint: disable=W0703 - stop_timer = perf_counter() - elapsed = stop_timer - start_timer - extra += f"&x-latency={round(elapsed, 2)}" - if not telemetry_opt_out_flag: - _send_telemetry_request( - STATUS_TO_CODE[str(Status.FAILURE)], - feature_list, - self.sagemaker_session, - str(e), - e.__class__.__name__, - extra, - ) - caught_ex = e - finally: - if caught_ex: - raise caught_ex - return response # pylint: disable=W0150 + # Construct the extra info to track platform and environment usage metadata + extra = ( + f"{func_name}" + f"&x-sdkVersion={SDK_VERSION}" + f"&x-env={PYTHON_VERSION}" + f"&x-sys={OS_NAME_VERSION}" + f"&x-platform={studio_app_type}" + ) + + # Add endpoint ARN to the extra info if available + if sagemaker_session.endpoint_arn: + extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}" + + start_timer = perf_counter() + try: + # Call the original function + response = func(*args, **kwargs) + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + extra += f"&x-latency={round(elapsed, 2)}" + if not telemetry_opt_out_flag: + _send_telemetry_request( + STATUS_TO_CODE[str(Status.SUCCESS)], + feature_list, + sagemaker_session, + None, + None, + extra, + ) + except Exception as e: # pylint: disable=W0703 + stop_timer = perf_counter() + elapsed = stop_timer - start_timer + extra += f"&x-latency={round(elapsed, 2)}" + if not telemetry_opt_out_flag: + _send_telemetry_request( + STATUS_TO_CODE[str(Status.FAILURE)], + feature_list, + sagemaker_session, + str(e), + e.__class__.__name__, + extra, + ) + caught_ex = e + finally: + if caught_ex: + raise caught_ex + return response # pylint: disable=W0150 + else: + logger.debug( + "Unable to send telemetry for function %s. " + "sagemaker_session is not provided or not valid.", + func_name, + ) + return func(*args, **kwargs) return wrapper return decorator -from sagemaker.session import Session # noqa: E402 pylint: disable=C0413 - - def _send_telemetry_request( status: int, feature_list: List[int], @@ -165,9 +184,9 @@ def _send_telemetry_request( # Send the telemetry request logger.debug("Sending telemetry request to [%s]", url) _requests_helper(url, 2) - logger.debug("SageMaker Python SDK telemetry successfully emitted!") + logger.debug("SageMaker Python SDK telemetry successfully emitted.") except Exception: # pylint: disable=W0703 - logger.debug("SageMaker Python SDK telemetry not emitted!!") + logger.debug("SageMaker Python SDK telemetry not emitted!") def _construct_url( @@ -196,9 +215,6 @@ def _construct_url( return base_url -import requests # noqa: E402 pylint: disable=C0413,C0411 - - def _requests_helper(url, timeout): """Make a GET request to the given URL""" diff --git a/tests/unit/sagemaker/remote_function/test_client.py b/tests/unit/sagemaker/remote_function/test_client.py index 20d05a933e..1d752f89ed 100644 --- a/tests/unit/sagemaker/remote_function/test_client.py +++ b/tests/unit/sagemaker/remote_function/test_client.py @@ -15,6 +15,7 @@ import os import threading import time +import inspect import pytest from mock import MagicMock, patch, Mock, ANY, call @@ -1498,7 +1499,6 @@ def test_consistency_between_remote_and_step_decorator(): from sagemaker.workflow.function_step import step remote_args_to_ignore = [ - "_remote", "include_local_workdir", "custom_file_filter", "s3_kms_key", @@ -1508,7 +1508,7 @@ def test_consistency_between_remote_and_step_decorator(): step_args_to_ignore = ["_step", "name", "display_name", "description", "retry_policies"] - remote_decorator_args = remote.__code__.co_varnames + remote_decorator_args = inspect.signature(remote).parameters.keys() common_remote_decorator_args = set(remote_args_to_ignore) ^ set(remote_decorator_args) step_decorator_args = step.__code__.co_varnames @@ -1522,8 +1522,7 @@ def test_consistency_between_remote_and_executor(): executor_arg_list.remove("self") executor_arg_list.remove("max_parallel_jobs") - remote_args_list = list(remote.__code__.co_varnames) - remote_args_list.remove("_remote") + remote_args_list = list(inspect.signature(remote).parameters.keys()) remote_args_list.remove("_func") assert executor_arg_list == remote_args_list From a64dacef6d377c2e1aaae5106e627b14399e56f0 Mon Sep 17 00:00:00 2001 From: knikure Date: Fri, 14 Jun 2024 16:37:24 +0000 Subject: [PATCH 2/3] Fix default session issue --- src/sagemaker/telemetry/telemetry_logging.py | 13 ++++++++++++- .../telemetry/test_telemetry_logging.py | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py index f56eaec0e0..d2b91a321c 100644 --- a/src/sagemaker/telemetry/telemetry_logging.py +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -20,6 +20,7 @@ import functools import requests +import boto3 from sagemaker.session import Session from sagemaker.utils import resolve_value_from_config from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH @@ -71,7 +72,9 @@ def wrapper(*args, **kwargs): sagemaker_session = args[0].sagemaker_session elif feature == Feature.REMOTE_FUNCTION: # Get the sagemaker_session from the function keyword arguments for remote function - sagemaker_session = kwargs.get("sagemaker_session", Session()) + sagemaker_session = kwargs.get( + "sagemaker_session", _get_default_sagemaker_session() + ) if sagemaker_session: logger.debug("sagemaker_session found, preparing to emit telemetry...") @@ -243,3 +246,11 @@ def _get_region_or_default(session): return session.boto_session.region_name except Exception: # pylint: disable=W0703 return DEFAULT_AWS_REGION + + +def _get_default_sagemaker_session(): + """Return the default sagemaker session""" + boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION) + sagemaker_session = Session(boto_session=boto_session) + + return sagemaker_session diff --git a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py index af782c8b51..dfc9b7e7ce 100644 --- a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py +++ b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py @@ -15,6 +15,7 @@ import pytest import requests from unittest.mock import Mock, patch, MagicMock +import boto3 import sagemaker from sagemaker.telemetry.constants import Feature from sagemaker.telemetry.telemetry_logging import ( @@ -24,6 +25,7 @@ _get_accountId, _requests_helper, _get_region_or_default, + _get_default_sagemaker_session, OS_NAME_VERSION, PYTHON_VERSION, ) @@ -282,3 +284,19 @@ def test_get_region_or_default_exception(self): region = _get_region_or_default(mock_session) assert region == "us-west-2" assert "Error creating boto session" in str(exception) + + @patch.object(boto3.Session, "region_name", "us-west-2") + def test_get_default_sagemaker_session(self): + sagemaker_session = _get_default_sagemaker_session() + + assert sagemaker_session is sagemaker.Session + assert sagemaker_session.boto_session.region_name == "us-west-2" + + @patch.object(boto3.Session, "region_name", None) + def test_get_default_sagemaker_session_with_no_region(self): + with self.assertRaises(ValueError) as context: + _get_default_sagemaker_session() + + assert "Must setup local AWS configuration with a region supported by SageMaker." in str( + context.exception + ) From b33d78281a26e27780146c8813a2aaebcbcb8152 Mon Sep 17 00:00:00 2001 From: knikure Date: Fri, 14 Jun 2024 19:34:04 +0000 Subject: [PATCH 3/3] fix unit-tests --- tests/unit/sagemaker/telemetry/test_telemetry_logging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py index dfc9b7e7ce..9107256b5b 100644 --- a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py +++ b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py @@ -289,7 +289,7 @@ def test_get_region_or_default_exception(self): def test_get_default_sagemaker_session(self): sagemaker_session = _get_default_sagemaker_session() - assert sagemaker_session is sagemaker.Session + assert isinstance(sagemaker_session, sagemaker.Session) is True assert sagemaker_session.boto_session.region_name == "us-west-2" @patch.object(boto3.Session, "region_name", None)