Skip to content

change: Enable telemetry logging for Remote function #4729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/sagemaker/remote_function/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
from sagemaker.utils import name_from_base, base_from_name
from sagemaker.remote_function.spark_config import SparkConfig
from sagemaker.remote_function.custom_file_filter import CustomFileFilter
from sagemaker.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.telemetry.constants import Feature

_API_CALL_LIMIT = {
"SubmittingIntervalInSecs": 1,
Expand All @@ -57,6 +59,7 @@
logger = logging_config.get_logger()


@_telemetry_emitter(feature=Feature.REMOTE_FUNCTION, func_name="remote_function.remote")
def remote(
_func=None,
*,
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/telemetry/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Feature(Enum):

SDK_DEFAULTS = 1
LOCAL_MODE = 2
REMOTE_FUNCTION = 3

def __str__(self): # pylint: disable=E0307
"""Return the feature name."""
Expand Down
177 changes: 102 additions & 75 deletions src/sagemaker/telemetry/telemetry_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import sys
from time import perf_counter
from typing import List
import functools
import requests

import boto3
from sagemaker.session import Session
from sagemaker.utils import resolve_value_from_config
from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH
from sagemaker.telemetry.constants import (
Expand Down Expand Up @@ -47,6 +51,7 @@
FEATURE_TO_CODE = {
str(Feature.SDK_DEFAULTS): 1,
str(Feature.LOCAL_MODE): 2,
str(Feature.REMOTE_FUNCTION): 3,
}

STATUS_TO_CODE = {
Expand All @@ -59,86 +64,103 @@ def _telemetry_emitter(feature: str, func_name: str):
"""Decorator to emit telemetry logs for SageMaker Python SDK functions"""

def decorator(func):
def wrapper(self, *args, **kwargs):
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
response = None
caught_ex = None
studio_app_type = process_studio_metadata_file()

# Check if telemetry is opted out
telemetry_opt_out_flag = resolve_value_from_config(
direct_input=None,
config_path=TELEMETRY_OPT_OUT_PATH,
default_value=False,
sagemaker_session=self.sagemaker_session,
)
logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)

# Construct the feature list to track feature combinations
feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]
if self.sagemaker_session:
if self.sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
@functools.wraps(func)
def wrapper(*args, **kwargs):
sagemaker_session = None
if len(args) > 0 and hasattr(args[0], "sagemaker_session"):
# Get the sagemaker_session from the instance method args
sagemaker_session = args[0].sagemaker_session
elif feature == Feature.REMOTE_FUNCTION:
# Get the sagemaker_session from the function keyword arguments for remote function
sagemaker_session = kwargs.get(
"sagemaker_session", _get_default_sagemaker_session()
)

if sagemaker_session:
logger.debug("sagemaker_session found, preparing to emit telemetry...")
logger.info(TELEMETRY_OPT_OUT_MESSAGING)
response = None
caught_ex = None
studio_app_type = process_studio_metadata_file()

# Check if telemetry is opted out
telemetry_opt_out_flag = resolve_value_from_config(
direct_input=None,
config_path=TELEMETRY_OPT_OUT_PATH,
default_value=False,
sagemaker_session=sagemaker_session,
)
logger.debug("TelemetryOptOut flag is set to: %s", telemetry_opt_out_flag)

# Construct the feature list to track feature combinations
feature_list: List[int] = [FEATURE_TO_CODE[str(feature)]]

if sagemaker_session.sagemaker_config and feature != Feature.SDK_DEFAULTS:
feature_list.append(FEATURE_TO_CODE[str(Feature.SDK_DEFAULTS)])

if self.sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
if sagemaker_session.local_mode and feature != Feature.LOCAL_MODE:
feature_list.append(FEATURE_TO_CODE[str(Feature.LOCAL_MODE)])

# Construct the extra info to track platform and environment usage metadata
extra = (
f"{func_name}"
f"&x-sdkVersion={SDK_VERSION}"
f"&x-env={PYTHON_VERSION}"
f"&x-sys={OS_NAME_VERSION}"
f"&x-platform={studio_app_type}"
)

# Add endpoint ARN to the extra info if available
if self.sagemaker_session and self.sagemaker_session.endpoint_arn:
extra += f"&x-endpointArn={self.sagemaker_session.endpoint_arn}"

start_timer = perf_counter()
try:
# Call the original function
response = func(self, *args, **kwargs)
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.SUCCESS)],
feature_list,
self.sagemaker_session,
None,
None,
extra,
)
except Exception as e: # pylint: disable=W0703
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.FAILURE)],
feature_list,
self.sagemaker_session,
str(e),
e.__class__.__name__,
extra,
)
caught_ex = e
finally:
if caught_ex:
raise caught_ex
return response # pylint: disable=W0150
# Construct the extra info to track platform and environment usage metadata
extra = (
f"{func_name}"
f"&x-sdkVersion={SDK_VERSION}"
f"&x-env={PYTHON_VERSION}"
f"&x-sys={OS_NAME_VERSION}"
f"&x-platform={studio_app_type}"
)

# Add endpoint ARN to the extra info if available
if sagemaker_session.endpoint_arn:
extra += f"&x-endpointArn={sagemaker_session.endpoint_arn}"

start_timer = perf_counter()
try:
# Call the original function
response = func(*args, **kwargs)
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.SUCCESS)],
feature_list,
sagemaker_session,
None,
None,
extra,
)
except Exception as e: # pylint: disable=W0703
stop_timer = perf_counter()
elapsed = stop_timer - start_timer
extra += f"&x-latency={round(elapsed, 2)}"
if not telemetry_opt_out_flag:
_send_telemetry_request(
STATUS_TO_CODE[str(Status.FAILURE)],
feature_list,
sagemaker_session,
str(e),
e.__class__.__name__,
extra,
)
caught_ex = e
finally:
if caught_ex:
raise caught_ex
return response # pylint: disable=W0150
else:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: we can move the else branch on top to reduce to clean the code a little bit. For example:

if not sagemaker:
   logger.debug(...)
    return ...

logger.debug("sagemaker_session found, preparing to emit telemetry...")
...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, Will do this in the follow-up PR for telemetry.

logger.debug(
"Unable to send telemetry for function %s. "
"sagemaker_session is not provided or not valid.",
func_name,
)
return func(*args, **kwargs)

return wrapper

return decorator


from sagemaker.session import Session # noqa: E402 pylint: disable=C0413


def _send_telemetry_request(
status: int,
feature_list: List[int],
Expand All @@ -165,9 +187,9 @@ def _send_telemetry_request(
# Send the telemetry request
logger.debug("Sending telemetry request to [%s]", url)
_requests_helper(url, 2)
logger.debug("SageMaker Python SDK telemetry successfully emitted!")
logger.debug("SageMaker Python SDK telemetry successfully emitted.")
except Exception: # pylint: disable=W0703
logger.debug("SageMaker Python SDK telemetry not emitted!!")
logger.debug("SageMaker Python SDK telemetry not emitted!")


def _construct_url(
Expand Down Expand Up @@ -196,9 +218,6 @@ def _construct_url(
return base_url


import requests # noqa: E402 pylint: disable=C0413,C0411


def _requests_helper(url, timeout):
"""Make a GET request to the given URL"""

Expand Down Expand Up @@ -227,3 +246,11 @@ def _get_region_or_default(session):
return session.boto_session.region_name
except Exception: # pylint: disable=W0703
return DEFAULT_AWS_REGION


def _get_default_sagemaker_session():
"""Return the default sagemaker session"""
boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION)
sagemaker_session = Session(boto_session=boto_session)

return sagemaker_session
7 changes: 3 additions & 4 deletions tests/unit/sagemaker/remote_function/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import threading
import time
import inspect

import pytest
from mock import MagicMock, patch, Mock, ANY, call
Expand Down Expand Up @@ -1498,7 +1499,6 @@ def test_consistency_between_remote_and_step_decorator():
from sagemaker.workflow.function_step import step

remote_args_to_ignore = [
"_remote",
"include_local_workdir",
"custom_file_filter",
"s3_kms_key",
Expand All @@ -1508,7 +1508,7 @@ def test_consistency_between_remote_and_step_decorator():

step_args_to_ignore = ["_step", "name", "display_name", "description", "retry_policies"]

remote_decorator_args = remote.__code__.co_varnames
remote_decorator_args = inspect.signature(remote).parameters.keys()
common_remote_decorator_args = set(remote_args_to_ignore) ^ set(remote_decorator_args)

step_decorator_args = step.__code__.co_varnames
Expand All @@ -1522,8 +1522,7 @@ def test_consistency_between_remote_and_executor():
executor_arg_list.remove("self")
executor_arg_list.remove("max_parallel_jobs")

remote_args_list = list(remote.__code__.co_varnames)
remote_args_list.remove("_remote")
remote_args_list = list(inspect.signature(remote).parameters.keys())
remote_args_list.remove("_func")

assert executor_arg_list == remote_args_list
18 changes: 18 additions & 0 deletions tests/unit/sagemaker/telemetry/test_telemetry_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import pytest
import requests
from unittest.mock import Mock, patch, MagicMock
import boto3
import sagemaker
from sagemaker.telemetry.constants import Feature
from sagemaker.telemetry.telemetry_logging import (
Expand All @@ -24,6 +25,7 @@
_get_accountId,
_requests_helper,
_get_region_or_default,
_get_default_sagemaker_session,
OS_NAME_VERSION,
PYTHON_VERSION,
)
Expand Down Expand Up @@ -282,3 +284,19 @@ def test_get_region_or_default_exception(self):
region = _get_region_or_default(mock_session)
assert region == "us-west-2"
assert "Error creating boto session" in str(exception)

@patch.object(boto3.Session, "region_name", "us-west-2")
def test_get_default_sagemaker_session(self):
sagemaker_session = _get_default_sagemaker_session()

assert isinstance(sagemaker_session, sagemaker.Session) is True
assert sagemaker_session.boto_session.region_name == "us-west-2"

@patch.object(boto3.Session, "region_name", None)
def test_get_default_sagemaker_session_with_no_region(self):
with self.assertRaises(ValueError) as context:
_get_default_sagemaker_session()

assert "Must setup local AWS configuration with a region supported by SageMaker." in str(
context.exception
)