Skip to content

Commit a64dace

Browse files
committed
Fix default session issue
1 parent 0453c77 commit a64dace

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

src/sagemaker/telemetry/telemetry_logging.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import functools
2121
import requests
2222

23+
import boto3
2324
from sagemaker.session import Session
2425
from sagemaker.utils import resolve_value_from_config
2526
from sagemaker.config.config_schema import TELEMETRY_OPT_OUT_PATH
@@ -71,7 +72,9 @@ def wrapper(*args, **kwargs):
7172
sagemaker_session = args[0].sagemaker_session
7273
elif feature == Feature.REMOTE_FUNCTION:
7374
# Get the sagemaker_session from the function keyword arguments for remote function
74-
sagemaker_session = kwargs.get("sagemaker_session", Session())
75+
sagemaker_session = kwargs.get(
76+
"sagemaker_session", _get_default_sagemaker_session()
77+
)
7578

7679
if sagemaker_session:
7780
logger.debug("sagemaker_session found, preparing to emit telemetry...")
@@ -243,3 +246,11 @@ def _get_region_or_default(session):
243246
return session.boto_session.region_name
244247
except Exception: # pylint: disable=W0703
245248
return DEFAULT_AWS_REGION
249+
250+
251+
def _get_default_sagemaker_session():
252+
"""Return the default sagemaker session"""
253+
boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION)
254+
sagemaker_session = Session(boto_session=boto_session)
255+
256+
return sagemaker_session

tests/unit/sagemaker/telemetry/test_telemetry_logging.py

+18
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616
import requests
1717
from unittest.mock import Mock, patch, MagicMock
18+
import boto3
1819
import sagemaker
1920
from sagemaker.telemetry.constants import Feature
2021
from sagemaker.telemetry.telemetry_logging import (
@@ -24,6 +25,7 @@
2425
_get_accountId,
2526
_requests_helper,
2627
_get_region_or_default,
28+
_get_default_sagemaker_session,
2729
OS_NAME_VERSION,
2830
PYTHON_VERSION,
2931
)
@@ -282,3 +284,19 @@ def test_get_region_or_default_exception(self):
282284
region = _get_region_or_default(mock_session)
283285
assert region == "us-west-2"
284286
assert "Error creating boto session" in str(exception)
287+
288+
@patch.object(boto3.Session, "region_name", "us-west-2")
289+
def test_get_default_sagemaker_session(self):
290+
sagemaker_session = _get_default_sagemaker_session()
291+
292+
assert sagemaker_session is sagemaker.Session
293+
assert sagemaker_session.boto_session.region_name == "us-west-2"
294+
295+
@patch.object(boto3.Session, "region_name", None)
296+
def test_get_default_sagemaker_session_with_no_region(self):
297+
with self.assertRaises(ValueError) as context:
298+
_get_default_sagemaker_session()
299+
300+
assert "Must setup local AWS configuration with a region supported by SageMaker." in str(
301+
context.exception
302+
)

0 commit comments

Comments
 (0)