12
12
# language governing permissions and limitations under the License.
13
13
from __future__ import absolute_import
14
14
import unittest
15
- from unittest .mock import MagicMock , patch
16
- from sagemaker .serve .utils .telemetry_logger import _send_telemetry
15
+ from unittest .mock import Mock , patch
16
+ from sagemaker .serve import Mode , ModelServer
17
+ from sagemaker .serve .utils .telemetry_logger import (
18
+ _send_telemetry ,
19
+ _capture_telemetry ,
20
+ _construct_url ,
21
+ )
22
+ from sagemaker .serve .utils .exceptions import ModelBuilderException , LocalModelOutOfMemoryException
17
23
18
- mock_session = MagicMock ()
24
+ MOCK_SESSION = Mock ()
25
+ MOCK_FUNC_NAME = "Mock.deploy"
26
+ MOCK_DJL_CONTAINER = (
27
+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/" "djl-inference:0.25.0-deepspeed0.11.0-cu118"
28
+ )
29
+ MOCK_TGI_CONTAINER = (
30
+ "763104351884.dkr.ecr.us-east-1.amazonaws.com/"
31
+ "huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
32
+ )
33
+ MOCK_HUGGINGFACE_ID = "meta-llama/Llama-2-7b-hf"
34
+ MOCK_EXCEPTION = LocalModelOutOfMemoryException ("mock raise ex" )
35
+
36
+
37
+ class ModelBuilderMock :
38
+ def __init__ (self ):
39
+ self .serve_settings = Mock ()
40
+ self .sagemaker_session = MOCK_SESSION
41
+
42
+ @_capture_telemetry (MOCK_FUNC_NAME )
43
+ def mock_deploy (self , mock_exception_func = None ):
44
+ if mock_exception_func :
45
+ mock_exception_func ()
19
46
20
47
21
48
class TestTelemetryLogger (unittest .TestCase ):
22
49
@patch ("sagemaker.serve.utils.telemetry_logger._requests_helper" )
23
50
@patch ("sagemaker.serve.utils.telemetry_logger._get_accountId" )
24
51
def test_log_sucessfully (self , mocked_get_accountId , mocked_request_helper ):
25
- mock_session .boto_session .region_name = "ap-south-1"
52
+ MOCK_SESSION .boto_session .region_name = "ap-south-1"
26
53
mocked_get_accountId .return_value = "testAccountId"
27
- _send_telemetry ("someStatus" , 1 , mock_session )
54
+ _send_telemetry ("someStatus" , 1 , MOCK_SESSION )
28
55
mocked_request_helper .assert_called_with (
29
56
"https://dev-exp-t-ap-south-1.s3.ap-south-1.amazonaws.com/"
30
57
"telemetry?x-accountId=testAccountId&x-mode=1&x-status=someStatus" ,
@@ -34,9 +61,120 @@ def test_log_sucessfully(self, mocked_get_accountId, mocked_request_helper):
34
61
@patch ("sagemaker.serve.utils.telemetry_logger._get_accountId" )
35
62
def test_log_handle_exception (self , mocked_get_accountId ):
36
63
mocked_get_accountId .side_effect = Exception ("Internal error" )
37
- _send_telemetry ("someStatus" , 1 , mock_session )
64
+ _send_telemetry ("someStatus" , 1 , MOCK_SESSION )
38
65
self .assertRaises (Exception )
39
66
67
+ @patch ("sagemaker.serve.utils.telemetry_logger._send_telemetry" )
68
+ def test_capture_telemetry_decorator_djl_success (self , mock_send_telemetry ):
69
+ mock_model_builder = ModelBuilderMock ()
70
+ mock_model_builder .serve_settings .telemetry_opt_out = False
71
+ mock_model_builder .image_uri = MOCK_DJL_CONTAINER
72
+ mock_model_builder .model = MOCK_HUGGINGFACE_ID
73
+ mock_model_builder .mode = Mode .LOCAL_CONTAINER
74
+ mock_model_builder .model_server = ModelServer .DJL_SERVING
75
+
76
+ mock_model_builder .mock_deploy ()
77
+
78
+ expected_extra_str = (
79
+ f"{ MOCK_FUNC_NAME } "
80
+ "&x-modelServer=4"
81
+ "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
82
+ f"&x-modelName={ MOCK_HUGGINGFACE_ID } "
83
+ )
84
+ mock_send_telemetry .assert_called_once_with (
85
+ "1" , 2 , MOCK_SESSION , None , None , expected_extra_str
86
+ )
87
+
88
+ @patch ("sagemaker.serve.utils.telemetry_logger._send_telemetry" )
89
+ def test_capture_telemetry_decorator_tgi_success (self , mock_send_telemetry ):
90
+ mock_model_builder = ModelBuilderMock ()
91
+ mock_model_builder .serve_settings .telemetry_opt_out = False
92
+ mock_model_builder .image_uri = MOCK_TGI_CONTAINER
93
+ mock_model_builder .model = MOCK_HUGGINGFACE_ID
94
+ mock_model_builder .mode = Mode .LOCAL_CONTAINER
95
+ mock_model_builder .model_server = ModelServer .TGI
96
+
97
+ mock_model_builder .mock_deploy ()
98
+
99
+ expected_extra_str = (
100
+ f"{ MOCK_FUNC_NAME } "
101
+ "&x-modelServer=6"
102
+ "&x-imageTag=huggingface-pytorch-inference:2.0.0-transformers4.28.1-cpu-py310-ubuntu20.04"
103
+ f"&x-modelName={ MOCK_HUGGINGFACE_ID } "
104
+ )
105
+ mock_send_telemetry .assert_called_once_with (
106
+ "1" , 2 , MOCK_SESSION , None , None , expected_extra_str
107
+ )
108
+
109
+ @patch ("sagemaker.serve.utils.telemetry_logger._send_telemetry" )
110
+ def test_capture_telemetry_decorator_no_call_when_disabled (self , mock_send_telemetry ):
111
+ mock_model_builder = ModelBuilderMock ()
112
+ mock_model_builder .serve_settings .telemetry_opt_out = True
113
+ mock_model_builder .image_uri = MOCK_DJL_CONTAINER
114
+ mock_model_builder .model = MOCK_HUGGINGFACE_ID
115
+ mock_model_builder .model_server = ModelServer .DJL_SERVING
40
116
41
- if __name__ == "__main__" :
42
- unittest .main ()
117
+ mock_model_builder .mock_deploy ()
118
+
119
+ assert not mock_send_telemetry .called
120
+
121
+ @patch ("sagemaker.serve.utils.telemetry_logger._send_telemetry" )
122
+ def test_capture_telemetry_decorator_handle_exception_success (self , mock_send_telemetry ):
123
+ mock_model_builder = ModelBuilderMock ()
124
+ mock_model_builder .serve_settings .telemetry_opt_out = False
125
+ mock_model_builder .image_uri = MOCK_DJL_CONTAINER
126
+ mock_model_builder .model = MOCK_HUGGINGFACE_ID
127
+ mock_model_builder .mode = Mode .LOCAL_CONTAINER
128
+ mock_model_builder .model_server = ModelServer .DJL_SERVING
129
+
130
+ mock_exception = Mock ()
131
+ mock_exception_obj = MOCK_EXCEPTION
132
+ mock_exception .side_effect = mock_exception_obj
133
+
134
+ with self .assertRaises (ModelBuilderException ) as _ :
135
+ mock_model_builder .mock_deploy (mock_exception )
136
+
137
+ expected_extra_str = (
138
+ f"{ MOCK_FUNC_NAME } "
139
+ "&x-modelServer=4"
140
+ "&x-imageTag=djl-inference:0.25.0-deepspeed0.11.0-cu118"
141
+ f"&x-modelName={ MOCK_HUGGINGFACE_ID } "
142
+ )
143
+ mock_send_telemetry .assert_called_once_with (
144
+ "0" ,
145
+ 2 ,
146
+ MOCK_SESSION ,
147
+ str (mock_exception_obj ),
148
+ mock_exception_obj .__class__ .__name__ ,
149
+ expected_extra_str ,
150
+ )
151
+
152
+ def test_construct_url_with_failure_reason_and_extra_info (self ):
153
+ mock_accountId = "12345678910"
154
+ mock_mode = Mode .LOCAL_CONTAINER
155
+ mock_status = "0"
156
+ mock_failure_reason = str (MOCK_EXCEPTION )
157
+ mock_failure_type = MOCK_EXCEPTION .__class__ .__name__
158
+ mock_extra_info = "mock_extra_info"
159
+ mock_region = "us-west-2"
160
+
161
+ ret_url = _construct_url (
162
+ accountId = mock_accountId ,
163
+ mode = mock_mode ,
164
+ status = mock_status ,
165
+ failure_reason = mock_failure_reason ,
166
+ failure_type = mock_failure_type ,
167
+ extra_info = mock_extra_info ,
168
+ region = mock_region ,
169
+ )
170
+
171
+ expected_base_url = (
172
+ f"https://dev-exp-t-{ mock_region } .s3.{ mock_region } .amazonaws.com/telemetry?"
173
+ f"x-accountId={ mock_accountId } "
174
+ f"&x-mode={ mock_mode } "
175
+ f"&x-status={ mock_status } "
176
+ f"&x-failureReason={ mock_failure_reason } "
177
+ f"&x-failureType={ mock_failure_type } "
178
+ f"&x-extra={ mock_extra_info } "
179
+ )
180
+ self .assertEquals (ret_url , expected_base_url )
0 commit comments