Skip to content

Commit 24c2e9d

Browse files
evakravijiapinw
authored andcommitted
feat: jumpstart telemetry (aws#4697)
* feat: jumpstart telemetry * chore: add unit tests * fix: unit tests * fix: pydocstyle * fix: user agent to comply with RFC style * chore: add unit test for http headers * fix: flake8
1 parent 13b631d commit 24c2e9d

File tree

7 files changed

+278
-53
lines changed

7 files changed

+278
-53
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737

3838
ENV_VARIABLE_DISABLE_JUMPSTART_LOGGING = "DISABLE_JUMPSTART_LOGGING"
39+
ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY = "DISABLE_JUMPSTART_TELEMETRY"
3940

4041
JUMPSTART_LAUNCHED_REGIONS: Set[JumpStartLaunchedRegionInfo] = set(
4142
[

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
_model_supports_training_model_uri,
4545
)
4646
from sagemaker.jumpstart.constants import (
47-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4847
JUMPSTART_DEFAULT_REGION_NAME,
4948
JUMPSTART_LOGGER,
5049
TRAINING_ENTRY_POINT_SCRIPT_NAME,
@@ -63,6 +62,7 @@
6362
from sagemaker.jumpstart.utils import (
6463
add_jumpstart_model_id_version_tags,
6564
get_eula_message,
65+
get_default_jumpstart_session_with_user_agent_suffix,
6666
update_dict_if_key_not_present,
6767
resolve_estimator_sagemaker_config_field,
6868
verify_model_region_and_return_specs,
@@ -403,7 +403,12 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
403403

404404
def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
405405
"""Sets session in kwargs based on default or override, returns full kwargs."""
406-
kwargs.sagemaker_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
406+
kwargs.sagemaker_session = (
407+
kwargs.sagemaker_session
408+
or get_default_jumpstart_session_with_user_agent_suffix(
409+
kwargs.model_id, kwargs.model_version
410+
)
411+
)
407412
return kwargs
408413

409414

src/sagemaker/jumpstart/factory/model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
)
3030
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
3131
from sagemaker.jumpstart.constants import (
32-
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3332
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
3433
JUMPSTART_DEFAULT_REGION_NAME,
3534
JUMPSTART_LOGGER,
@@ -45,6 +44,7 @@
4544
)
4645
from sagemaker.jumpstart.utils import (
4746
add_jumpstart_model_id_version_tags,
47+
get_default_jumpstart_session_with_user_agent_suffix,
4848
update_dict_if_key_not_present,
4949
resolve_model_sagemaker_config_field,
5050
verify_model_region_and_return_specs,
@@ -140,7 +140,12 @@ def _add_sagemaker_session_to_kwargs(
140140
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
141141
) -> JumpStartModelInitKwargs:
142142
"""Sets session in kwargs based on default or override, returns full kwargs."""
143-
kwargs.sagemaker_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
143+
kwargs.sagemaker_session = (
144+
kwargs.sagemaker_session
145+
or get_default_jumpstart_session_with_user_agent_suffix(
146+
kwargs.model_id, kwargs.model_version
147+
)
148+
)
144149
return kwargs
145150

146151

src/sagemaker/jumpstart/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains utilities related to SageMaker JumpStart."""
1414
from __future__ import absolute_import
15+
from copy import copy
1516
import logging
1617
import os
1718
from typing import Any, Dict, List, Set, Optional, Tuple, Union
1819
from urllib.parse import urlparse
1920
import boto3
2021
from packaging.version import Version
22+
import botocore
2123
import sagemaker
2224
from sagemaker.config.config_schema import (
2325
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
@@ -46,6 +48,7 @@
4648
from sagemaker.config import load_sagemaker_config
4749
from sagemaker.utils import resolve_value_from_config, TagsDict
4850
from sagemaker.workflow import is_pipeline_variable
51+
from sagemaker.user_agent import get_user_agent_extra_suffix
4952

5053

5154
def get_jumpstart_launched_regions_message() -> str:
@@ -982,3 +985,39 @@ def get_jumpstart_configs(
982985
if metadata_configs
983986
else {}
984987
)
988+
989+
990+
def get_jumpstart_user_agent_extra_suffix(model_id: str, model_version: str) -> str:
991+
"""Returns the model-specific user agent string to be added to requests."""
992+
sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
993+
jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}"
994+
return (
995+
sagemaker_python_sdk_headers
996+
if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None)
997+
else f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}"
998+
)
999+
1000+
1001+
def get_default_jumpstart_session_with_user_agent_suffix(
1002+
model_id: str, model_version: str
1003+
) -> Session:
1004+
"""Returns default JumpStart SageMaker Session with model-specific user agent suffix."""
1005+
botocore_session = botocore.session.get_session()
1006+
botocore_config = botocore.config.Config(
1007+
user_agent_extra=get_jumpstart_user_agent_extra_suffix(model_id, model_version),
1008+
)
1009+
botocore_session.set_default_client_config(botocore_config)
1010+
# shallow copy to not affect default session constant
1011+
session = copy(constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION)
1012+
session.boto_session = boto3.Session(
1013+
region_name=constants.JUMPSTART_DEFAULT_REGION_NAME, botocore_session=botocore_session
1014+
)
1015+
session.sagemaker_client = boto3.client(
1016+
"sagemaker", region_name=constants.JUMPSTART_DEFAULT_REGION_NAME, config=botocore_config
1017+
)
1018+
session.sagemaker_runtime_client = boto3.client(
1019+
"sagemaker-runtime",
1020+
region_name=constants.JUMPSTART_DEFAULT_REGION_NAME,
1021+
config=botocore_config,
1022+
)
1023+
return session

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 87 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,12 @@ class EstimatorTest(unittest.TestCase):
6767
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER")
6868
@mock.patch("sagemaker.utils.sagemaker_timestamp")
6969
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
70-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
71-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
70+
@mock.patch(
71+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
72+
)
73+
@mock.patch(
74+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
75+
)
7276
@mock.patch("sagemaker.jumpstart.utils.validate_model_id_and_get_type")
7377
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
7478
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
@@ -193,8 +197,12 @@ def test_non_prepacked(
193197
)
194198

195199
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
196-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
197-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
200+
@mock.patch(
201+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
202+
)
203+
@mock.patch(
204+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
205+
)
198206
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
199207
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
200208
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -293,8 +301,12 @@ def test_prepacked(
293301

294302
@mock.patch("sagemaker.utils.sagemaker_timestamp")
295303
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
296-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
297-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
304+
@mock.patch(
305+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
306+
)
307+
@mock.patch(
308+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
309+
)
298310
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
299311
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
300312
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -411,8 +423,12 @@ def test_gated_model_s3_uri(
411423
)
412424
@mock.patch("sagemaker.utils.sagemaker_timestamp")
413425
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
414-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
415-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
426+
@mock.patch(
427+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
428+
)
429+
@mock.patch(
430+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
431+
)
416432
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
417433
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
418434
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -559,8 +575,12 @@ def test_gated_model_non_model_package_s3_uri(
559575

560576
@mock.patch("sagemaker.utils.sagemaker_timestamp")
561577
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
562-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
563-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
578+
@mock.patch(
579+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
580+
)
581+
@mock.patch(
582+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
583+
)
564584
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
565585
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
566586
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -751,8 +771,12 @@ def test_estimator_use_kwargs(self):
751771
@mock.patch("sagemaker.jumpstart.factory.estimator.environment_variables.retrieve_default")
752772
@mock.patch("sagemaker.utils.sagemaker_timestamp")
753773
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
754-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
755-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
774+
@mock.patch(
775+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
776+
)
777+
@mock.patch(
778+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
779+
)
756780
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
757781
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
758782
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -1160,8 +1184,12 @@ def test_validate_model_id_and_get_type(
11601184

11611185
@mock.patch("sagemaker.jumpstart.estimator.get_default_predictor")
11621186
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
1163-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1164-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
1187+
@mock.patch(
1188+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
1189+
)
1190+
@mock.patch(
1191+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1192+
)
11651193
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
11661194
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
11671195
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -1219,8 +1247,12 @@ def test_no_predictor_returns_default_predictor(
12191247

12201248
@mock.patch("sagemaker.jumpstart.estimator.get_default_predictor")
12211249
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
1222-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1223-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
1250+
@mock.patch(
1251+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
1252+
)
1253+
@mock.patch(
1254+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1255+
)
12241256
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
12251257
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
12261258
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -1269,8 +1301,12 @@ def test_no_predictor_yes_async_inference_config(
12691301

12701302
@mock.patch("sagemaker.jumpstart.estimator.get_default_predictor")
12711303
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
1272-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1273-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
1304+
@mock.patch(
1305+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
1306+
)
1307+
@mock.patch(
1308+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1309+
)
12741310
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
12751311
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
12761312
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -1321,8 +1357,12 @@ def test_yes_predictor_returns_unmodified_predictor(
13211357
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
13221358
@mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training")
13231359
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
1324-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1325-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
1360+
@mock.patch(
1361+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
1362+
)
1363+
@mock.patch(
1364+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1365+
)
13261366
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
13271367
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
13281368
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -1375,8 +1415,12 @@ def test_incremental_training_with_unsupported_model_logs_warning(
13751415
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
13761416
@mock.patch("sagemaker.jumpstart.factory.estimator._model_supports_incremental_training")
13771417
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER.warning")
1378-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1379-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
1418+
@mock.patch(
1419+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
1420+
)
1421+
@mock.patch(
1422+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1423+
)
13801424
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
13811425
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
13821426
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -1425,8 +1469,12 @@ def test_incremental_training_with_supported_model_doesnt_log_warning(
14251469

14261470
@mock.patch("sagemaker.utils.sagemaker_timestamp")
14271471
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
1428-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1429-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
1472+
@mock.patch(
1473+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
1474+
)
1475+
@mock.patch(
1476+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1477+
)
14301478
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
14311479
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
14321480
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -1486,8 +1534,12 @@ def test_estimator_sets_different_inference_instance_depending_on_training_insta
14861534

14871535
@mock.patch("sagemaker.utils.sagemaker_timestamp")
14881536
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
1489-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1490-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
1537+
@mock.patch(
1538+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
1539+
)
1540+
@mock.patch(
1541+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1542+
)
14911543
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
14921544
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
14931545
@mock.patch("sagemaker.jumpstart.estimator.Estimator.fit")
@@ -1564,10 +1616,11 @@ def test_training_passes_role_to_deploy(
15641616
@mock.patch("sagemaker.utils.sagemaker_timestamp")
15651617
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
15661618
@mock.patch(
1567-
"sagemaker.jumpstart.factory.model.DEFAULT_JUMPSTART_SAGEMAKER_SESSION", sagemaker_session
1619+
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix",
1620+
sagemaker_session,
15681621
)
15691622
@mock.patch(
1570-
"sagemaker.jumpstart.factory.estimator.DEFAULT_JUMPSTART_SAGEMAKER_SESSION",
1623+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix",
15711624
sagemaker_session,
15721625
)
15731626
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@@ -1645,7 +1698,9 @@ def test_training_passes_session_to_deploy(
16451698
@mock.patch("sagemaker.jumpstart.estimator.Estimator.deploy")
16461699
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
16471700
@mock.patch("sagemaker.jumpstart.factory.estimator._retrieve_estimator_init_kwargs")
1648-
@mock.patch("sagemaker.jumpstart.factory.estimator.Session")
1701+
@mock.patch(
1702+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1703+
)
16491704
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
16501705
@mock.patch("sagemaker.jumpstart.estimator.JumpStartModelsAccessor.reset_cache")
16511706
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)
@@ -1725,7 +1780,9 @@ def test_model_id_not_found_refeshes_cache_training(
17251780
)
17261781

17271782
@mock.patch("sagemaker.jumpstart.estimator.validate_model_id_and_get_type")
1728-
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1783+
@mock.patch(
1784+
"sagemaker.jumpstart.factory.estimator.get_default_jumpstart_session_with_user_agent_suffix"
1785+
)
17291786
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
17301787
@mock.patch("sagemaker.jumpstart.estimator.Estimator.__init__")
17311788
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_DEFAULT_REGION_NAME", region)

0 commit comments

Comments
 (0)