diff --git a/CHANGELOG.md b/CHANGELOG.md index 99416fe44a..38092bf59e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## v2.218.1 (2024-05-03) + +### Bug Fixes and Other Changes + + * Fix UserAgent logging in Python SDK + * chore: release tgi 2.0.1 + * chore: update skipped flaky tests + ## v2.218.0 (2024-05-01) ### Features diff --git a/VERSION b/VERSION index c611a0a1ab..b298acdcc9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.218.1.dev0 +2.218.2.dev0 diff --git a/src/sagemaker/image_uri_config/huggingface-llm.json b/src/sagemaker/image_uri_config/huggingface-llm.json index 10073338e7..d357367e6e 100644 --- a/src/sagemaker/image_uri_config/huggingface-llm.json +++ b/src/sagemaker/image_uri_config/huggingface-llm.json @@ -12,7 +12,7 @@ "1.2": "1.2.0", "1.3": "1.3.3", "1.4": "1.4.5", - "2.0": "2.0.0" + "2.0": "2.0.1" }, "versions": { "0.6.0": { @@ -578,6 +578,53 @@ "container_version": { "gpu": "cu121-ubuntu22.04" } + }, + "2.0.1": { + "py_versions": [ + "py310" + ], + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ap-southeast-4": "457447274322", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "me-south-1": "217643126080", + "me-central-1": "914824155844", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-isob-east-1": "094389454867", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "tag_prefix": "2.1.1-tgi2.0.1", + "repository": "huggingface-pytorch-tgi-inference", + "container_version": { + "gpu": "cu121-ubuntu22.04" + } } } } diff --git a/src/sagemaker/jumpstart/notebook_utils.py b/src/sagemaker/jumpstart/notebook_utils.py index 52645f89f3..781548b42a 100644 --- a/src/sagemaker/jumpstart/notebook_utils.py +++ b/src/sagemaker/jumpstart/notebook_utils.py @@ -329,9 +329,12 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin return sorted(list(model_id_version_dict.keys())) if not list_old_models: - model_id_version_dict = { - model_id: set([max(versions)]) for model_id, versions in model_id_version_dict.items() - } + for model_id, versions in model_id_version_dict.items(): + try: + model_id_version_dict.update({model_id: set([max(versions)])}) + except TypeError: + versions = [str(v) for v in versions] + model_id_version_dict.update({model_id: set([max(versions)])}) model_id_version_set: Set[Tuple[str, str]] = set() for model_id in model_id_version_dict: diff --git a/src/sagemaker/jumpstart/payload_utils.py b/src/sagemaker/jumpstart/payload_utils.py index 595f801598..e4d31e9c83 100644 --- a/src/sagemaker/jumpstart/payload_utils.py +++ b/src/sagemaker/jumpstart/payload_utils.py @@ -23,7 +23,7 @@ from sagemaker.jumpstart.constants import ( DEFAULT_JUMPSTART_SAGEMAKER_SESSION, ) -from sagemaker.jumpstart.enums import MIMEType +from sagemaker.jumpstart.enums import JumpStartModelType, MIMEType from sagemaker.jumpstart.types import JumpStartSerializablePayload from sagemaker.jumpstart.utils import ( get_jumpstart_content_bucket, @@ -61,6 +61,7 @@ def _construct_payload( tolerate_vulnerable_model: bool = False, tolerate_deprecated_model: bool = False, sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION, + model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, ) -> Optional[JumpStartSerializablePayload]: """Returns example payload from prompt. @@ -83,6 +84,8 @@ def _construct_payload( object, used for SageMaker interactions. If not specified, one is created using the default AWS configuration chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION). + model_type (JumpStartModelType): The type of the model, can be open weights model or + proprietary model. (Default: JumpStartModelType.OPEN_WEIGHTS). Returns: Optional[JumpStartSerializablePayload]: serializable payload with prompt, or None if this feature is unavailable for the specified model. @@ -94,6 +97,7 @@ def _construct_payload( tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, sagemaker_session=sagemaker_session, + model_type=model_type, ) if payloads is None or len(payloads) == 0: return None diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index ab2eeed7f0..7a51d075ae 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1097,7 +1097,7 @@ def __init__( config (Dict[str, Any]): Dictionary representation of the config. base_fields (Dict[str, Any]): - The default base fields that are used to construct the final resolved config. + The default base fields that are used to construct the resolved config. config_components (Dict[str, JumpStartConfigComponent]): The list of components that are used to construct the resolved config. """ diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 5ea3d5f8a1..bf2a736871 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -121,7 +121,7 @@ from sagemaker.deprecations import deprecated_class from sagemaker.enums import EndpointType from sagemaker.inputs import ShuffleConfig, TrainingInput, BatchDataCaptureConfig -from sagemaker.user_agent import prepend_user_agent +from sagemaker.user_agent import get_user_agent_extra_suffix from sagemaker.utils import ( name_from_image, secondary_training_status_changed, @@ -285,6 +285,7 @@ def _initialize( Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client. Sets the region_name. """ + self.boto_session = boto_session or boto3.DEFAULT_SESSION or boto3.Session() self._region_name = self.boto_session.region_name @@ -293,19 +294,30 @@ def _initialize( "Must setup local AWS configuration with a region supported by SageMaker." ) - self.sagemaker_client = sagemaker_client or self.boto_session.client("sagemaker") - prepend_user_agent(self.sagemaker_client) + # Make use of user_agent_extra field of the botocore_config object + # to append SageMaker Python SDK specific user_agent suffix + # to the current User-Agent header value from boto3 + # This config will also make sure that user_agent never fails to log the User-Agent string + # even if boto User-Agent header format is updated in the future + # Ref: https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + botocore_config = botocore.config.Config(user_agent_extra=get_user_agent_extra_suffix()) + + # Create sagemaker_client with the botocore_config object + # This config is customized to append SageMaker Python SDK specific user_agent suffix + self.sagemaker_client = sagemaker_client or self.boto_session.client( + "sagemaker", config=botocore_config + ) if sagemaker_runtime_client is not None: self.sagemaker_runtime_client = sagemaker_runtime_client else: - config = botocore.config.Config(read_timeout=80) + config = botocore.config.Config( + read_timeout=80, user_agent_extra=get_user_agent_extra_suffix() + ) self.sagemaker_runtime_client = self.boto_session.client( "runtime.sagemaker", config=config ) - prepend_user_agent(self.sagemaker_runtime_client) - if sagemaker_featurestore_runtime_client: self.sagemaker_featurestore_runtime_client = sagemaker_featurestore_runtime_client else: @@ -316,8 +328,9 @@ def _initialize( if sagemaker_metrics_client: self.sagemaker_metrics_client = sagemaker_metrics_client else: - self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics") - prepend_user_agent(self.sagemaker_metrics_client) + self.sagemaker_metrics_client = self.boto_session.client( + "sagemaker-metrics", config=botocore_config + ) self.s3_client = self.boto_session.client("s3", region_name=self.boto_region_name) self.s3_resource = self.boto_session.resource("s3", region_name=self.boto_region_name) diff --git a/src/sagemaker/user_agent.py b/src/sagemaker/user_agent.py index 8af89696c2..c1b2bcac07 100644 --- a/src/sagemaker/user_agent.py +++ b/src/sagemaker/user_agent.py @@ -13,8 +13,6 @@ """Placeholder docstring""" from __future__ import absolute_import -import platform -import sys import json import os @@ -28,12 +26,6 @@ STUDIO_METADATA_FILE = "/opt/ml/metadata/resource-metadata.json" SDK_VERSION = importlib_metadata.version("sagemaker") -OS_NAME = platform.system() or "UnresolvedOS" -OS_VERSION = platform.release() or "UnresolvedOSVersion" -OS_NAME_VERSION = "{}/{}".format(OS_NAME, OS_VERSION) -PYTHON_VERSION = "Python/{}.{}.{}".format( - sys.version_info.major, sys.version_info.minor, sys.version_info.micro -) def process_notebook_metadata_file(): @@ -63,45 +55,24 @@ def process_studio_metadata_file(): return None -def determine_prefix(user_agent=""): - """Determines the prefix for the user agent string. +def get_user_agent_extra_suffix(): + """Get the user agent extra suffix string specific to SageMaker Python SDK - Args: - user_agent (str): The user agent string to prepend the prefix to. + Adhers to new boto recommended User-Agent 2.0 header format Returns: - str: The user agent string with the prefix prepended. + str: The user agent extra suffix string to be appended """ - prefix = "{}/{}".format(SDK_PREFIX, SDK_VERSION) - - if PYTHON_VERSION not in user_agent: - prefix = "{} {}".format(prefix, PYTHON_VERSION) - - if OS_NAME_VERSION not in user_agent: - prefix = "{} {}".format(prefix, OS_NAME_VERSION) + suffix = "lib/{}#{}".format(SDK_PREFIX, SDK_VERSION) # Get the notebook instance type and prepend it to the user agent string if exists notebook_instance_type = process_notebook_metadata_file() if notebook_instance_type: - prefix = "{} {}/{}".format(prefix, NOTEBOOK_PREFIX, notebook_instance_type) + suffix = "{} md/{}#{}".format(suffix, NOTEBOOK_PREFIX, notebook_instance_type) # Get the studio app type and prepend it to the user agent string if exists studio_app_type = process_studio_metadata_file() if studio_app_type: - prefix = "{} {}/{}".format(prefix, STUDIO_PREFIX, studio_app_type) - - return prefix - - -def prepend_user_agent(client): - """Prepends the user agent string with the SageMaker Python SDK version. - - Args: - client (botocore.client.BaseClient): The client to prepend the user agent string for. - """ - prefix = determine_prefix(client._client_config.user_agent) + suffix = "{} md/{}#{}".format(suffix, STUDIO_PREFIX, studio_app_type) - if client._client_config.user_agent is None: - client._client_config.user_agent = prefix - else: - client._client_config.user_agent = "{} {}".format(prefix, client._client_config.user_agent) + return suffix diff --git a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py index 582e5cf82d..2ef981a109 100644 --- a/tests/unit/sagemaker/image_uris/test_huggingface_llm.py +++ b/tests/unit/sagemaker/image_uris/test_huggingface_llm.py @@ -32,6 +32,7 @@ "1.4.2": "2.1.1-tgi1.4.2-gpu-py310-cu121-ubuntu22.04", "1.4.5": "2.1.1-tgi1.4.5-gpu-py310-cu121-ubuntu22.04", "2.0.0": "2.1.1-tgi2.0.0-gpu-py310-cu121-ubuntu22.04", + "2.0.1": "2.1.1-tgi2.0.1-gpu-py310-cu121-ubuntu22.04", }, "inf2": { "0.0.16": "1.13.1-optimum0.0.16-neuronx-py310-ubuntu22.04", diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index 50f35d19bb..646b672cae 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -3,7 +3,6 @@ from unittest import TestCase from unittest.mock import Mock, patch -import datetime import pytest from sagemaker.jumpstart.constants import ( @@ -17,7 +16,6 @@ get_prototype_manifest, get_prototype_model_spec, ) -from tests.unit.sagemaker.jumpstart.constants import BASE_PROPRIETARY_MANIFEST from sagemaker.jumpstart.enums import JumpStartModelType from sagemaker.jumpstart.notebook_utils import ( _generate_jumpstart_model_versions, @@ -242,7 +240,7 @@ def test_list_jumpstart_models_script_filter( manifest_length = len(get_prototype_manifest()) vals = [True, False] for val in vals: - kwargs = {"filter": f"training_supported == {val}"} + kwargs = {"filter": And(f"training_supported == {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 @@ -250,15 +248,17 @@ def test_list_jumpstart_models_script_filter( patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - kwargs = {"filter": f"training_supported != {val}"} + kwargs = {"filter": And(f"training_supported != {val}", "model_type is open_weights")} list_jumpstart_models(**kwargs) assert patched_read_s3_file.call_count == 2 * manifest_length assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - - kwargs = {"filter": f"training_supported in {vals}", "list_versions": True} + kwargs = { + "filter": And(f"training_supported != {val}", "model_type is open_weights"), + "list_versions": True, + } assert list_jumpstart_models(**kwargs) == [ ("catboost-classification-model", "1.0.0"), ("huggingface-spc-bert-base-cased", "1.0.0"), @@ -275,7 +275,7 @@ def test_list_jumpstart_models_script_filter( patched_get_manifest.reset_mock() patched_read_s3_file.reset_mock() - kwargs = {"filter": f"training_supported not in {vals}"} + kwargs = {"filter": And(f"training_supported not in {vals}", "model_type is open_weights")} models = list_jumpstart_models(**kwargs) assert [] == models assert patched_read_s3_file.call_count == 2 * manifest_length @@ -514,10 +514,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME): list_old_models=False, list_versions=True ) == list_jumpstart_models(list_versions=True) - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_vulnerable_models( @@ -543,12 +539,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): patched_read_s3_file.side_effect = vulnerable_inference_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) - num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) assert [] == list_jumpstart_models( - And("inference_vulnerable is false", "training_vulnerable is false") + And( + "inference_vulnerable is false", + "training_vulnerable is false", + "model_type is open_weights", + ) ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -557,10 +556,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): patched_read_s3_file.side_effect = vulnerable_training_model_spec assert [] == list_jumpstart_models( - And("inference_vulnerable is false", "training_vulnerable is false") + And( + "inference_vulnerable is false", + "training_vulnerable is false", + "model_type is open_weights", + ) ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() @@ -570,10 +573,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs): assert patched_read_s3_file.call_count == 0 - @pytest.mark.skipif( - datetime.datetime.now() < datetime.datetime(year=2024, month=5, day=1), - reason="Contact JumpStart team to fix flaky test.", - ) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file") def test_list_jumpstart_models_deprecated_models( @@ -594,10 +593,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str: patched_read_s3_file.side_effect = deprecated_model_spec num_specs = len(PROTOTYPICAL_MODEL_SPECS_DICT) - num_prop_specs = len(BASE_PROPRIETARY_MANIFEST) - assert [] == list_jumpstart_models("deprecated equals false") + assert [] == list_jumpstart_models( + And("deprecated equals false", "model_type is open_weights") + ) - assert patched_read_s3_file.call_count == num_specs + num_prop_specs + assert patched_read_s3_file.call_count == num_specs assert patched_get_manifest.call_count == 2 patched_get_manifest.reset_mock() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 944f22acff..f7dede1ce9 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -43,8 +43,6 @@ from sagemaker.utils import update_list_of_dicts_with_values_from_config from sagemaker.user_agent import ( SDK_PREFIX, - STUDIO_PREFIX, - NOTEBOOK_PREFIX, ) from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from tests.unit import ( @@ -87,15 +85,20 @@ limits={}, ) +SDK_DEFAULT_SUFFIX = f"lib/{SDK_PREFIX}#2.218.0" +NOTEBOOK_SUFFIX = f"{SDK_DEFAULT_SUFFIX} md/AWS-SageMaker-Notebook-Instance#instance_type" +STUDIO_SUFFIX = f"{SDK_DEFAULT_SUFFIX} md/AWS-SageMaker-Studio#app_type" -@pytest.fixture() -def boto_session(): - boto_mock = Mock(name="boto_session", region_name=REGION) +@pytest.fixture +def boto_session(request): + boto_user_agent = "Boto3/1.33.9 md/Botocore#1.33.9 ua/2.0 os/linux#linux-ver md/arch#x86_64 lang/python#3.10.6" + user_agent_suffix = getattr(request, "param", "") + boto_mock = Mock(name="boto_session", region_name=REGION) client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.9.69 Python/3.6.5 Linux/4.14.77-70.82.amzn1.x86_64 Botocore/1.12.69 Resource" - ) + user_agent = f"{boto_user_agent} {SDK_DEFAULT_SUFFIX} {user_agent_suffix}" + with patch("sagemaker.user_agent.get_user_agent_extra_suffix", return_value=user_agent_suffix): + client_mock._client_config.user_agent = user_agent boto_mock.client.return_value = client_mock return boto_mock @@ -887,65 +890,42 @@ def test_delete_model(boto_session): boto_session.client().delete_model.assert_called_with(ModelName=model_name) +@pytest.mark.parametrize("boto_session", [""], indirect=True) def test_user_agent_injected(boto_session): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - sess = Session(boto_session) - + expected_user_agent_suffix = "lib/AWS-SageMaker-Python-SDK#2.218.0" for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX not in client._client_config.user_agent - assert STUDIO_PREFIX not in client._client_config.user_agent + assert expected_user_agent_suffix in client._client_config.user_agent -@patch("sagemaker.user_agent.process_notebook_metadata_file", return_value="ml.t3.medium") -def test_user_agent_injected_with_nbi( - mock_process_notebook_metadata_file, - boto_session, -): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - - sess = Session( - boto_session=boto_session, +@pytest.mark.parametrize("boto_session", [f"{NOTEBOOK_SUFFIX}"], indirect=True) +def test_user_agent_with_notebook_instance_type(boto_session): + sess = Session(boto_session) + expected_user_agent_suffix = ( + "lib/AWS-SageMaker-Python-SDK#2.218.0 md/AWS-SageMaker-Notebook-Instance#instance_type" ) - for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - mock_process_notebook_metadata_file.assert_called() - - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX in client._client_config.user_agent - assert STUDIO_PREFIX not in client._client_config.user_agent + assert expected_user_agent_suffix in client._client_config.user_agent -@patch("sagemaker.user_agent.process_studio_metadata_file", return_value="dymmy-app-type") -def test_user_agent_injected_with_studio_app_type( - mock_process_studio_metadata_file, - boto_session, -): - assert SDK_PREFIX not in boto_session.client("sagemaker")._client_config.user_agent - - sess = Session( - boto_session=boto_session, - ) - +@pytest.mark.parametrize("boto_session", [f"{STUDIO_SUFFIX}"], indirect=True) +def test_user_agent_with_studio_app_type(boto_session): + sess = Session(boto_session) + expected_user_agent = "lib/AWS-SageMaker-Python-SDK#2.218.0 md/AWS-SageMaker-Studio#app_type" for client in [ sess.sagemaker_client, sess.sagemaker_runtime_client, sess.sagemaker_metrics_client, ]: - mock_process_studio_metadata_file.assert_called() - - assert SDK_PREFIX in client._client_config.user_agent - assert NOTEBOOK_PREFIX not in client._client_config.user_agent - assert STUDIO_PREFIX in client._client_config.user_agent + assert expected_user_agent in client._client_config.user_agent def test_training_input_all_defaults(): diff --git a/tests/unit/test_user_agent.py b/tests/unit/test_user_agent.py index c116fef951..fb46988e7b 100644 --- a/tests/unit/test_user_agent.py +++ b/tests/unit/test_user_agent.py @@ -13,20 +13,17 @@ from __future__ import absolute_import import json -from mock import MagicMock, patch, mock_open +from mock import patch, mock_open from sagemaker.user_agent import ( SDK_PREFIX, SDK_VERSION, - PYTHON_VERSION, - OS_NAME_VERSION, NOTEBOOK_PREFIX, STUDIO_PREFIX, process_notebook_metadata_file, process_studio_metadata_file, - determine_prefix, - prepend_user_agent, + get_user_agent_extra_suffix, ) @@ -60,45 +57,18 @@ def test_process_studio_metadata_file_not_exists(tmp_path): assert process_studio_metadata_file() is None -# Test determine_prefix function -def test_determine_prefix_notebook_instance_type(monkeypatch): - monkeypatch.setattr( - "sagemaker.user_agent.process_notebook_metadata_file", lambda: "instance_type" - ) - assert ( - determine_prefix() - == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION} {NOTEBOOK_PREFIX}/instance_type" - ) - - -def test_determine_prefix_studio_app_type(monkeypatch): - monkeypatch.setattr( - "sagemaker.user_agent.process_studio_metadata_file", lambda: "studio_app_type" - ) - assert ( - determine_prefix() - == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION} {STUDIO_PREFIX}/studio_app_type" - ) - - -def test_determine_prefix_no_metadata(monkeypatch): - monkeypatch.setattr("sagemaker.user_agent.process_notebook_metadata_file", lambda: None) - monkeypatch.setattr("sagemaker.user_agent.process_studio_metadata_file", lambda: None) - assert determine_prefix() == f"{SDK_PREFIX}/{SDK_VERSION} {PYTHON_VERSION} {OS_NAME_VERSION}" - - -# Test prepend_user_agent function -def test_prepend_user_agent_existing_user_agent(monkeypatch): - client = MagicMock() - client._client_config.user_agent = "existing_user_agent" - monkeypatch.setattr("sagemaker.user_agent.determine_prefix", lambda _: "prefix") - prepend_user_agent(client) - assert client._client_config.user_agent == "prefix existing_user_agent" - - -def test_prepend_user_agent_no_user_agent(monkeypatch): - client = MagicMock() - client._client_config.user_agent = None - monkeypatch.setattr("sagemaker.user_agent.determine_prefix", lambda _: "prefix") - prepend_user_agent(client) - assert client._client_config.user_agent == "prefix" +# Test get_user_agent_extra_suffix function +def test_get_user_agent_extra_suffix(): + assert get_user_agent_extra_suffix() == f"lib/{SDK_PREFIX}#{SDK_VERSION}" + + with patch("sagemaker.user_agent.process_notebook_metadata_file", return_value="instance_type"): + assert ( + get_user_agent_extra_suffix() + == f"lib/{SDK_PREFIX}#{SDK_VERSION} md/{NOTEBOOK_PREFIX}#instance_type" + ) + + with patch("sagemaker.user_agent.process_studio_metadata_file", return_value="studio_type"): + assert ( + get_user_agent_extra_suffix() + == f"lib/{SDK_PREFIX}#{SDK_VERSION} md/{STUDIO_PREFIX}#studio_type" + )