File tree 2 files changed +30
-1
lines changed
tests/unit/sagemaker/telemetry
2 files changed +30
-1
lines changed Original file line number Diff line number Diff line change 20
20
import functools
21
21
import requests
22
22
23
+ import boto3
23
24
from sagemaker .session import Session
24
25
from sagemaker .utils import resolve_value_from_config
25
26
from sagemaker .config .config_schema import TELEMETRY_OPT_OUT_PATH
@@ -71,7 +72,9 @@ def wrapper(*args, **kwargs):
71
72
sagemaker_session = args [0 ].sagemaker_session
72
73
elif feature == Feature .REMOTE_FUNCTION :
73
74
# Get the sagemaker_session from the function keyword arguments for remote function
74
- sagemaker_session = kwargs .get ("sagemaker_session" , Session ())
75
+ sagemaker_session = kwargs .get (
76
+ "sagemaker_session" , _get_default_sagemaker_session ()
77
+ )
75
78
76
79
if sagemaker_session :
77
80
logger .debug ("sagemaker_session found, preparing to emit telemetry..." )
@@ -243,3 +246,11 @@ def _get_region_or_default(session):
243
246
return session .boto_session .region_name
244
247
except Exception : # pylint: disable=W0703
245
248
return DEFAULT_AWS_REGION
249
+
250
+
251
+ def _get_default_sagemaker_session ():
252
+ """Return the default sagemaker session"""
253
+ boto_session = boto3 .Session (region_name = DEFAULT_AWS_REGION )
254
+ sagemaker_session = Session (boto_session = boto_session )
255
+
256
+ return sagemaker_session
Original file line number Diff line number Diff line change 15
15
import pytest
16
16
import requests
17
17
from unittest .mock import Mock , patch , MagicMock
18
+ import boto3
18
19
import sagemaker
19
20
from sagemaker .telemetry .constants import Feature
20
21
from sagemaker .telemetry .telemetry_logging import (
24
25
_get_accountId ,
25
26
_requests_helper ,
26
27
_get_region_or_default ,
28
+ _get_default_sagemaker_session ,
27
29
OS_NAME_VERSION ,
28
30
PYTHON_VERSION ,
29
31
)
@@ -282,3 +284,19 @@ def test_get_region_or_default_exception(self):
282
284
region = _get_region_or_default (mock_session )
283
285
assert region == "us-west-2"
284
286
assert "Error creating boto session" in str (exception )
287
+
288
+ @patch .object (boto3 .Session , "region_name" , "us-west-2" )
289
+ def test_get_default_sagemaker_session (self ):
290
+ sagemaker_session = _get_default_sagemaker_session ()
291
+
292
+ assert sagemaker_session is sagemaker .Session
293
+ assert sagemaker_session .boto_session .region_name == "us-west-2"
294
+
295
+ @patch .object (boto3 .Session , "region_name" , None )
296
+ def test_get_default_sagemaker_session_with_no_region (self ):
297
+ with self .assertRaises (ValueError ) as context :
298
+ _get_default_sagemaker_session ()
299
+
300
+ assert "Must setup local AWS configuration with a region supported by SageMaker." in str (
301
+ context .exception
302
+ )
You can’t perform that action at this time.
0 commit comments