From 25f16ef757e7874b925b426ecf5d1130206d3c2c Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Tue, 28 Jan 2025 23:57:10 -0800 Subject: [PATCH 01/11] change: Allow telemetry only in supported regions --- src/sagemaker/telemetry/constants.py | 35 ++++++++++++++++++ src/sagemaker/telemetry/telemetry_logging.py | 16 +++++++-- .../telemetry/test_telemetry_logging.py | 36 +++++++++++++++++++ 3 files changed, 85 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py index 2108ff9fd6..28bb758b05 100644 --- a/src/sagemaker/telemetry/constants.py +++ b/src/sagemaker/telemetry/constants.py @@ -42,3 +42,38 @@ class Status(Enum): def __str__(self): # pylint: disable=E0307 """Return the status name.""" return self.name + + +class Region(str, Enum): + # Classic + US_EAST_1 = "us-east-1" # IAD + US_EAST_2 = "us-east-2" # CMH + US_WEST_1 = "us-west-1" # SFO + US_WEST_2 = "us-west-2" # PDX + AP_NORTHEAST_1 = "ap-northeast-1" # NRT + AP_NORTHEAST_2 = "ap-northeast-2" # ICN + AP_NORTHEAST_3 = "ap-northeast-3" # KIX + AP_SOUTH_1 = "ap-south-1" # BOM + AP_SOUTHEAST_1 = "ap-southeast-1" # SIN + AP_SOUTHEAST_2 = "ap-southeast-2" # SYD + CA_CENTRAL_1 = "ca-central-1" # YUL + EU_CENTRAL_1 = "eu-central-1" # FRA + EU_NORTH_1 = "eu-north-1" # ARN + EU_WEST_1 = "eu-west-1" # DUB + EU_WEST_2 = "eu-west-2" # LHR + EU_WEST_3 = "eu-west-3" # CDG + SA_EAST_1 = "sa-east-1" # GRU + # Opt-in + AP_EAST_1 = "ap-east-1" # HKG + AP_SOUTHEAST_3 = "ap-southeast-3" # CGK + AF_SOUTH_1 = "af-south-1" # CPT + EU_SOUTH_1 = "eu-south-1" # MXP + ME_SOUTH_1 = "me-south-1" # BAH + MX_CENTRAL_1 = "mx-central-1" # QRO + AP_SOUTHEAST_7 = "ap-southeast-7" # BKK + AP_SOUTH_2 = "ap-south-2" # HYD + AP_SOUTHEAST_4 = "ap-southeast-4" # MEL + EU_CENTRAL_2 = "eu-central-2" # ZRH + EU_SOUTH_2 = "eu-south-2" # ZAZ + IL_CENTRAL_1 = "il-central-1" # TLV + ME_CENTRAL_1 = "me-central-1" # DXB diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py index b45550b2c2..55c0e205d9 100644 --- a/src/sagemaker/telemetry/telemetry_logging.py +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -27,6 +27,7 @@ from sagemaker.telemetry.constants import ( Feature, Status, + Region, DEFAULT_AWS_REGION, ) from sagemaker.user_agent import SDK_VERSION, process_studio_metadata_file @@ -189,8 +190,19 @@ def _send_telemetry_request( """Make GET request to an empty object in S3 bucket""" try: accountId = _get_accountId(session) if session else "NotAvailable" - # telemetry will be sent to us-west-2 if no session availale - region = _get_region_or_default(session) if session else DEFAULT_AWS_REGION + + # Validate region if session exists + if session: + region = _get_region_or_default(session) + try: + Region(region) + except ValueError: + logger.debug( + "Region not found in supported regions. Telemetry request will not be emitted." + ) + return + else: # telemetry will be sent to us-west-2 if no session available + region = DEFAULT_AWS_REGION url = _construct_url( accountId, region, diff --git a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py index 9107256b5b..bd8db82a16 100644 --- a/tests/unit/sagemaker/telemetry/test_telemetry_logging.py +++ b/tests/unit/sagemaker/telemetry/test_telemetry_logging.py @@ -300,3 +300,39 @@ def test_get_default_sagemaker_session_with_no_region(self): assert "Must setup local AWS configuration with a region supported by SageMaker." 
in str( context.exception ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_valid_region(self, mock_get_region, mock_get_accountId): + """Test to verify telemetry request is sent when region is valid""" + mock_get_accountId.return_value = "testAccountId" + mock_session = MagicMock() + + # Test with valid region + mock_get_region.return_value = "us-east-1" + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + _send_telemetry_request(1, [1, 2], mock_session) + # Assert telemetry request was sent + mock_requests_helper.assert_called_once_with( + "https://sm-pysdk-t-us-east-1.s3.us-east-1.amazonaws.com/telemetry?" + "x-accountId=testAccountId&x-status=1&x-feature=1,2", + 2, + ) + + @patch("sagemaker.telemetry.telemetry_logging._get_accountId") + @patch("sagemaker.telemetry.telemetry_logging._get_region_or_default") + def test_send_telemetry_request_invalid_region(self, mock_get_region, mock_get_accountId): + """Test to verify telemetry request is not sent when region is invalid""" + mock_get_accountId.return_value = "testAccountId" + mock_session = MagicMock() + + # Test with invalid region + mock_get_region.return_value = "invalid-region" + with patch( + "sagemaker.telemetry.telemetry_logging._requests_helper" + ) as mock_requests_helper: + _send_telemetry_request(1, [1, 2], mock_session) + # Assert telemetry request was not sent + mock_requests_helper.assert_not_called() From 0ed85d67e6cfd8bfe39529a12aaa5dfe43785835 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Tue, 28 Jan 2025 23:57:10 -0800 Subject: [PATCH 02/11] change: Allow telemetry only in supported regions --- src/sagemaker/telemetry/constants.py | 37 ++++++++++++++++++++ src/sagemaker/telemetry/telemetry_logging.py | 2 +- 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py index 28bb758b05..a18a4a4a0f 100644 --- a/src/sagemaker/telemetry/constants.py +++ b/src/sagemaker/telemetry/constants.py @@ -44,6 +44,43 @@ def __str__(self): # pylint: disable=E0307 return self.name +class Region(str, Enum): + """Telemetry: List of all supported AWS regions.""" + + # Classic + US_EAST_1 = "us-east-1" # IAD + US_EAST_2 = "us-east-2" # CMH + US_WEST_1 = "us-west-1" # SFO + US_WEST_2 = "us-west-2" # PDX + AP_NORTHEAST_1 = "ap-northeast-1" # NRT + AP_NORTHEAST_2 = "ap-northeast-2" # ICN + AP_NORTHEAST_3 = "ap-northeast-3" # KIX + AP_SOUTH_1 = "ap-south-1" # BOM + AP_SOUTHEAST_1 = "ap-southeast-1" # SIN + AP_SOUTHEAST_2 = "ap-southeast-2" # SYD + CA_CENTRAL_1 = "ca-central-1" # YUL + EU_CENTRAL_1 = "eu-central-1" # FRA + EU_NORTH_1 = "eu-north-1" # ARN + EU_WEST_1 = "eu-west-1" # DUB + EU_WEST_2 = "eu-west-2" # LHR + EU_WEST_3 = "eu-west-3" # CDG + SA_EAST_1 = "sa-east-1" # GRU + # Opt-in + AP_EAST_1 = "ap-east-1" # HKG + AP_SOUTHEAST_3 = "ap-southeast-3" # CGK + AF_SOUTH_1 = "af-south-1" # CPT + EU_SOUTH_1 = "eu-south-1" # MXP + ME_SOUTH_1 = "me-south-1" # BAH + MX_CENTRAL_1 = "mx-central-1" # QRO + AP_SOUTHEAST_7 = "ap-southeast-7" # BKK + AP_SOUTH_2 = "ap-south-2" # HYD + AP_SOUTHEAST_4 = "ap-southeast-4" # MEL + EU_CENTRAL_2 = "eu-central-2" # ZRH + EU_SOUTH_2 = "eu-south-2" # ZAZ + IL_CENTRAL_1 = "il-central-1" # TLV + ME_CENTRAL_1 = "me-central-1" # DXB + + class Region(str, Enum): # Classic US_EAST_1 = "us-east-1" # IAD diff --git a/src/sagemaker/telemetry/telemetry_logging.py 
b/src/sagemaker/telemetry/telemetry_logging.py index 55c0e205d9..887a574ca1 100644 --- a/src/sagemaker/telemetry/telemetry_logging.py +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -197,7 +197,7 @@ def _send_telemetry_request( try: Region(region) except ValueError: - logger.debug( + logger.warning( "Region not found in supported regions. Telemetry request will not be emitted." ) return From b69ffcb952948626166c35ec4264d6fff8a2ce17 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Tue, 28 Jan 2025 23:57:10 -0800 Subject: [PATCH 03/11] change: Allow telemetry only in supported regions --- src/sagemaker/telemetry/constants.py | 34 ---------------------------- 1 file changed, 34 deletions(-) diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py index a18a4a4a0f..d6f19dc3d2 100644 --- a/src/sagemaker/telemetry/constants.py +++ b/src/sagemaker/telemetry/constants.py @@ -80,37 +80,3 @@ class Region(str, Enum): IL_CENTRAL_1 = "il-central-1" # TLV ME_CENTRAL_1 = "me-central-1" # DXB - -class Region(str, Enum): - # Classic - US_EAST_1 = "us-east-1" # IAD - US_EAST_2 = "us-east-2" # CMH - US_WEST_1 = "us-west-1" # SFO - US_WEST_2 = "us-west-2" # PDX - AP_NORTHEAST_1 = "ap-northeast-1" # NRT - AP_NORTHEAST_2 = "ap-northeast-2" # ICN - AP_NORTHEAST_3 = "ap-northeast-3" # KIX - AP_SOUTH_1 = "ap-south-1" # BOM - AP_SOUTHEAST_1 = "ap-southeast-1" # SIN - AP_SOUTHEAST_2 = "ap-southeast-2" # SYD - CA_CENTRAL_1 = "ca-central-1" # YUL - EU_CENTRAL_1 = "eu-central-1" # FRA - EU_NORTH_1 = "eu-north-1" # ARN - EU_WEST_1 = "eu-west-1" # DUB - EU_WEST_2 = "eu-west-2" # LHR - EU_WEST_3 = "eu-west-3" # CDG - SA_EAST_1 = "sa-east-1" # GRU - # Opt-in - AP_EAST_1 = "ap-east-1" # HKG - AP_SOUTHEAST_3 = "ap-southeast-3" # CGK - AF_SOUTH_1 = "af-south-1" # CPT - EU_SOUTH_1 = "eu-south-1" # MXP - ME_SOUTH_1 = "me-south-1" # BAH - MX_CENTRAL_1 = "mx-central-1" # QRO - AP_SOUTHEAST_7 = "ap-southeast-7" # BKK - AP_SOUTH_2 = "ap-south-2" # HYD - AP_SOUTHEAST_4 = "ap-southeast-4" # MEL - EU_CENTRAL_2 = "eu-central-2" # ZRH - EU_SOUTH_2 = "eu-south-2" # ZAZ - IL_CENTRAL_1 = "il-central-1" # TLV - ME_CENTRAL_1 = "me-central-1" # DXB From 8d7f4a8e3c1645b4d892479cbcbd723951e77081 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Wed, 29 Jan 2025 13:03:28 -0800 Subject: [PATCH 04/11] change: Allow telemetry only in supported regions --- src/sagemaker/telemetry/constants.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/telemetry/constants.py b/src/sagemaker/telemetry/constants.py index d6f19dc3d2..cb83a78279 100644 --- a/src/sagemaker/telemetry/constants.py +++ b/src/sagemaker/telemetry/constants.py @@ -79,4 +79,3 @@ class Region(str, Enum): EU_SOUTH_2 = "eu-south-2" # ZAZ IL_CENTRAL_1 = "il-central-1" # TLV ME_CENTRAL_1 = "me-central-1" # DXB - From dadbb220d42205b4658dde8e09861a4b72389507 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Thu, 30 Jan 2025 12:18:59 -0800 Subject: [PATCH 05/11] change: Allow telemetry only in supported regions --- src/sagemaker/telemetry/telemetry_logging.py | 22 +++++++++----------- 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/telemetry/telemetry_logging.py b/src/sagemaker/telemetry/telemetry_logging.py index 887a574ca1..b0ecedee4c 100644 --- a/src/sagemaker/telemetry/telemetry_logging.py +++ b/src/sagemaker/telemetry/telemetry_logging.py @@ -190,19 +190,16 @@ def _send_telemetry_request( """Make GET request to an empty object in S3 bucket""" try: accountId = _get_accountId(session) if session 
else "NotAvailable" + region = _get_region_or_default(session) + + try: + Region(region) # Validate the region + except ValueError: + logger.warning( + "Region not found in supported regions. Telemetry request will not be emitted." + ) + return - # Validate region if session exists - if session: - region = _get_region_or_default(session) - try: - Region(region) - except ValueError: - logger.warning( - "Region not found in supported regions. Telemetry request will not be emitted." - ) - return - else: # telemetry will be sent to us-west-2 if no session available - region = DEFAULT_AWS_REGION url = _construct_url( accountId, region, @@ -280,6 +277,7 @@ def _get_region_or_default(session): def _get_default_sagemaker_session(): """Return the default sagemaker session""" + boto_session = boto3.Session(region_name=DEFAULT_AWS_REGION) sagemaker_session = Session(boto_session=boto_session) From 7775c635e26ba11045fe1dccb020b44c86c6cf83 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Mon, 24 Feb 2025 09:34:28 -0800 Subject: [PATCH 06/11] documentation: Removed a line about python version requirements of training script which can misguide users.Training script can be of latest version based on the support provided by framework_version of the container --- doc/frameworks/pytorch/using_pytorch.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/doc/frameworks/pytorch/using_pytorch.rst b/doc/frameworks/pytorch/using_pytorch.rst index c50376920e..4141dd84db 100644 --- a/doc/frameworks/pytorch/using_pytorch.rst +++ b/doc/frameworks/pytorch/using_pytorch.rst @@ -28,8 +28,6 @@ To train a PyTorch model by using the SageMaker Python SDK: Prepare a PyTorch Training Script ================================= -Your PyTorch training script must be a Python 3.6 compatible source file. - Prepare your script in a separate source file than the notebook, terminal session, or source file you're using to submit the script to SageMaker via a ``PyTorch`` Estimator. This will be discussed in further detail below. From 58f8746ea537cf5918ead6515944657b4ce22356 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Wed, 12 Mar 2025 08:49:04 -0700 Subject: [PATCH 07/11] feature: Enabled update_endpoint through model_builder --- src/sagemaker/huggingface/model.py | 5 + src/sagemaker/model.py | 53 +++++-- src/sagemaker/serve/builder/model_builder.py | 15 +- src/sagemaker/session.py | 36 +++++ src/sagemaker/tensorflow/model.py | 2 + tests/unit/sagemaker/model/test_deploy.py | 140 ++++++++++++++++++ .../serve/builder/test_model_builder.py | 87 ++++++++++- 7 files changed, 320 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 05b981d21b..f4b44fc057 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -218,6 +218,7 @@ def deploy( container_startup_health_check_timeout=None, inference_recommendation_id=None, explainer_config=None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -296,6 +297,9 @@ def deploy( would like to deploy the model and endpoint with recommended parameters. explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability configuration for use with Amazon SageMaker Clarify. (default: None) + update_endpoint (Optional[bool]): Flag to update the model in an existing Amazon SageMaker endpoint. 
+ If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources + corresponding to the previous EndpointConfig. Default: False Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -335,6 +339,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, inference_recommendation_id=inference_recommendation_id, explainer_config=explainer_config, + update_endpoint=update_endpoint, **kwargs, ) diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index e5ea1ea314..24be862ecc 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -53,7 +53,6 @@ from sagemaker.model_card.schema_constraints import ModelApprovalStatusEnum from sagemaker.session import Session from sagemaker.model_metrics import ModelMetrics -from sagemaker.deprecations import removed_kwargs from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.explainer import ExplainerConfig from sagemaker.metadata_properties import MetadataProperties @@ -1386,6 +1385,7 @@ def deploy( routing_config: Optional[Dict[str, Any]] = None, model_reference_arn: Optional[str] = None, inference_ami_version: Optional[str] = None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy this ``Model`` to an ``Endpoint`` and optionally return a ``Predictor``. @@ -1497,6 +1497,10 @@ def deploy( inference_ami_version (Optional [str]): Specifies an option from a collection of preconfigured Amazon Machine Image (AMI) images. For a full list of options, see: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html + update_endpoint (Optional[bool]): Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources + corresponding to the previous EndpointConfig. Default: False + Note: Currently this is supported for single model endpoints Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or @@ -1512,8 +1516,6 @@ def deploy( """ self.accept_eula = accept_eula - removed_kwargs("update_endpoint", kwargs) - self._init_sagemaker_session_if_does_not_exist(instance_type) # Depending on the instance type, a local session (or) a session is initialized. 
self.role = resolve_value_from_config( @@ -1628,6 +1630,8 @@ def deploy( # Support multiple models on same endpoint if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED: + if update_endpoint: + raise ValueError("Currently update_endpoint is supported for single model endpoints") if endpoint_name: self.endpoint_name = endpoint_name else: @@ -1783,17 +1787,38 @@ def deploy( if is_explainer_enabled: explainer_config_dict = explainer_config._to_request_dict() - self.sagemaker_session.endpoint_from_production_variants( - name=self.endpoint_name, - production_variants=[production_variant], - tags=tags, - kms_key=kms_key, - wait=wait, - data_capture_config_dict=data_capture_config_dict, - explainer_config_dict=explainer_config_dict, - async_inference_config_dict=async_inference_config_dict, - live_logging=endpoint_logging, - ) + if update_endpoint: + endpoint_config_name = self.sagemaker_session.create_endpoint_config( + name=self.name, + model_name=self.name, + initial_instance_count=initial_instance_count, + instance_type=instance_type, + accelerator_type=accelerator_type, + tags=tags, + kms_key=kms_key, + data_capture_config_dict=data_capture_config_dict, + volume_size=volume_size, + model_data_download_timeout=model_data_download_timeout, + container_startup_health_check_timeout=container_startup_health_check_timeout, + explainer_config_dict=explainer_config_dict, + async_inference_config_dict=async_inference_config_dict, + serverless_inference_config=serverless_inference_config_dict, + routing_config=routing_config, + inference_ami_version=inference_ami_version, + ) + self.sagemaker_session.update_endpoint(self.endpoint_name, endpoint_config_name) + else: + self.sagemaker_session.endpoint_from_production_variants( + name=self.endpoint_name, + production_variants=[production_variant], + tags=tags, + kms_key=kms_key, + wait=wait, + data_capture_config_dict=data_capture_config_dict, + explainer_config_dict=explainer_config_dict, + async_inference_config_dict=async_inference_config_dict, + live_logging=endpoint_logging, + ) if self.predictor_cls: predictor = self.predictor_cls(self.endpoint_name, self.sagemaker_session) diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index a7a518105c..1a4746a6db 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1602,6 +1602,7 @@ def deploy( ResourceRequirements, ] ] = None, + update_endpoint: Optional[bool] = False, ) -> Union[Predictor, Transformer]: """Deploys the built Model. @@ -1615,24 +1616,32 @@ def deploy( AsyncInferenceConfig, BatchTransformInferenceConfig, ResourceRequirements]]) : Additional Config for different deployment types such as serverless, async, batch and multi-model/container + update_endpoint (Optional[bool]): Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources + corresponding to the previous EndpointConfig. 
Default: False + Note: Currently this is supported for single model endpoints Returns: Transformer for Batch Deployments Predictors for all others """ if not hasattr(self, "built_model"): raise ValueError("Model Needs to be built before deploying") - endpoint_name = unique_name_from_base(endpoint_name) + if not update_endpoint: + endpoint_name = unique_name_from_base(endpoint_name) + if not inference_config: # Real-time Deployment return self.built_model.deploy( instance_type=self.instance_type, initial_instance_count=initial_instance_count, endpoint_name=endpoint_name, + update_endpoint=update_endpoint, ) if isinstance(inference_config, ServerlessInferenceConfig): return self.built_model.deploy( serverless_inference_config=inference_config, endpoint_name=endpoint_name, + update_endpoint=update_endpoint, ) if isinstance(inference_config, AsyncInferenceConfig): @@ -1641,6 +1650,7 @@ def deploy( initial_instance_count=initial_instance_count, async_inference_config=inference_config, endpoint_name=endpoint_name, + update_endpoint=update_endpoint, ) if isinstance(inference_config, BatchTransformInferenceConfig): @@ -1652,6 +1662,8 @@ def deploy( return transformer if isinstance(inference_config, ResourceRequirements): + if update_endpoint: + raise ValueError("Currently update_endpoint is supported for single model endpoints") # Multi Model and MultiContainer endpoints with Inference Component return self.built_model.deploy( instance_type=self.instance_type, @@ -1660,6 +1672,7 @@ def deploy( resources=inference_config, initial_instance_count=initial_instance_count, role=self.role_arn, + update_endpoint=update_endpoint, ) raise ValueError("Deployment Options not supported") diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index b2398e03d1..4a65b3ccb5 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4488,6 +4488,10 @@ def create_endpoint_config( model_data_download_timeout=None, container_startup_health_check_timeout=None, explainer_config_dict=None, + async_inference_config_dict=None, + serverless_inference_config_dict=None, + routing_config: Optional[Dict[str, Any]] = None, + inference_ami_version: Optional[str] = None, ): """Create an Amazon SageMaker endpoint configuration. @@ -4525,6 +4529,27 @@ def create_endpoint_config( -inference-algo-ping-requests explainer_config_dict (dict): Specifies configuration to enable explainers. Default: None. + async_inference_config_dict (dict): Specifies + configuration related to async endpoint. Use this configuration when trying + to create async endpoint and make async inference. If empty config object + passed through, will use default config to deploy async endpoint. Deploy a + real-time endpoint if it's None. (default: None). + serverless_inference_config_dict (dict): + Specifies configuration related to serverless endpoint. Use this configuration + when trying to create serverless endpoint and make serverless inference. If + empty object passed through, will use pre-defined values in + ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an + instance based endpoint if it's None. (default: None). + routing_config (Optional[Dict[str, Any]): Settings the control how the endpoint routes incoming + traffic to the instances that the endpoint hosts. + Currently, support dictionary key ``RoutingStrategy``. + + .. 
code:: python + + { + "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM + } + inference_ami_version (Optional [str]): Specifies an option from a collection of preconfigured Example: >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}] @@ -4544,9 +4569,12 @@ def create_endpoint_config( instance_type, initial_instance_count, accelerator_type=accelerator_type, + serverless_inference_config=serverless_inference_config_dict, volume_size=volume_size, model_data_download_timeout=model_data_download_timeout, container_startup_health_check_timeout=container_startup_health_check_timeout, + routing_config=routing_config, + inference_ami_version=inference_ami_version, ) production_variants = [provided_production_variant] # Currently we just inject CoreDumpConfig.KmsKeyId from the config for production variant. @@ -4586,6 +4614,14 @@ def create_endpoint_config( ) request["DataCaptureConfig"] = inferred_data_capture_config_dict + if async_inference_config_dict is not None: + inferred_async_inference_config_dict = update_nested_dictionary_with_values_from_config( + async_inference_config_dict, + ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH, + sagemaker_session=self, + ) + request["AsyncInferenceConfig"] = inferred_async_inference_config_dict + if explainer_config_dict is not None: request["ExplainerConfig"] = explainer_config_dict diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index c7f624114f..b384cbbbb5 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -358,6 +358,7 @@ def deploy( container_startup_health_check_timeout=None, inference_recommendation_id=None, explainer_config=None, + update_endpoint: Optional[bool] = False, **kwargs, ): """Deploy a Tensorflow ``Model`` to a SageMaker ``Endpoint``.""" @@ -383,6 +384,7 @@ def deploy( container_startup_health_check_timeout=container_startup_health_check_timeout, inference_recommendation_id=inference_recommendation_id, explainer_config=explainer_config, + update_endpoint=update_endpoint, **kwargs, ) diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 7b99281b96..49a67871e7 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -23,6 +23,7 @@ from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.explainer import ExplainerConfig from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements +from sagemaker.enums import EndpointType from tests.unit.sagemaker.inference_recommender.constants import ( DESCRIBE_COMPILATION_JOB_RESPONSE, DESCRIBE_MODEL_PACKAGE_RESPONSE, @@ -1051,3 +1052,142 @@ def test_deploy_with_name_and_resources(sagemaker_session): async_inference_config_dict=None, live_logging=False, ) + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.utils.name_from_base", return_value=ENDPOINT_NAME) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_update_endpoint(production_variant, name_from_base, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + # Mock the create_endpoint_config to return a specific config name + endpoint_config_name = "test-config-name" + sagemaker_session.create_endpoint_config.return_value = endpoint_config_name + + # Test update_endpoint=True scenario + endpoint_name = "existing-endpoint" + model.deploy( + instance_type=INSTANCE_TYPE, + 
initial_instance_count=INSTANCE_COUNT, + endpoint_name=endpoint_name, + update_endpoint=True, + ) + + # Verify create_endpoint_config is called with correct parameters + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=None, + serverless_inference_config=None, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with correct parameters + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + # Test update_endpoint with serverless config + serverless_inference_config = ServerlessInferenceConfig() + serverless_inference_config_dict = { + "MemorySizeInMB": 2048, + "MaxConcurrency": 5, + } + model.deploy( + endpoint_name=endpoint_name, + update_endpoint=True, + serverless_inference_config=serverless_inference_config, + ) + + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=None, + instance_type=None, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=None, + serverless_inference_config=serverless_inference_config_dict, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with the new config + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + # Test update_endpoint with async inference config + async_inference_config = AsyncInferenceConfig( + output_path="s3://bucket/output", + failure_path="s3://bucket/failure" + ) + async_inference_config_dict = { + "OutputConfig": { + "S3OutputPath": "s3://bucket/output", + "S3FailurePath": "s3://bucket/failure" + }, + } + model.deploy( + endpoint_name=endpoint_name, + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + update_endpoint=True, + async_inference_config=async_inference_config, + ) + + sagemaker_session.create_endpoint_config.assert_called_with( + name=MODEL_NAME, + model_name=MODEL_NAME, + initial_instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + accelerator_type=None, + tags=None, + kms_key=None, + data_capture_config_dict=None, + volume_size=None, + model_data_download_timeout=None, + container_startup_health_check_timeout=None, + explainer_config_dict=None, + async_inference_config_dict=async_inference_config_dict, + serverless_inference_config=None, + routing_config=None, + inference_ami_version=None, + ) + + # Verify update_endpoint is called with the new config + sagemaker_session.update_endpoint.assert_called_with(endpoint_name, endpoint_config_name) + + +@patch("sagemaker.model.Model._create_sagemaker_model", Mock()) +@patch("sagemaker.production_variant", return_value=BASE_PRODUCTION_VARIANT) +def test_deploy_with_update_endpoint_inference_component(production_variant, sagemaker_session): + model = Model( + MODEL_IMAGE, MODEL_DATA, role=ROLE, name=MODEL_NAME, sagemaker_session=sagemaker_session + ) + + # Test that updating endpoint with inference component raises error + with 
pytest.raises(ValueError, match="Currently update_endpoint is supported for single model endpoints"): + model.deploy( + endpoint_name="test-endpoint", + instance_type=INSTANCE_TYPE, + initial_instance_count=INSTANCE_COUNT, + update_endpoint=True, + resources=RESOURCES, + endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, + ) diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index 107d65c301..ee37bf7b43 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -4041,14 +4041,30 @@ def test_neuron_configurations_rule_set(self): @pytest.mark.parametrize( "test_case", [ + # Real-time deployment without update { "input_args": {"endpoint_name": "test"}, "call_params": { "instance_type": "ml.g5.2xlarge", "initial_instance_count": 1, "endpoint_name": "test", + "update_endpoint": False, }, }, + # Real-time deployment with update + { + "input_args": { + "endpoint_name": "existing-endpoint", + "update_endpoint": True, + }, + "call_params": { + "instance_type": "ml.g5.2xlarge", + "initial_instance_count": 1, + "endpoint_name": "existing-endpoint", + "update_endpoint": True, + }, + }, + # Serverless deployment without update { "input_args": { "endpoint_name": "test", @@ -4057,8 +4073,23 @@ def test_neuron_configurations_rule_set(self): "call_params": { "serverless_inference_config": ServerlessInferenceConfig(), "endpoint_name": "test", + "update_endpoint": False, + }, + }, + # Serverless deployment with update + { + "input_args": { + "endpoint_name": "existing-endpoint", + "inference_config": ServerlessInferenceConfig(), + "update_endpoint": True, + }, + "call_params": { + "serverless_inference_config": ServerlessInferenceConfig(), + "endpoint_name": "existing-endpoint", + "update_endpoint": True, }, }, + # Async deployment without update { "input_args": { "endpoint_name": "test", @@ -4069,10 +4100,30 @@ def test_neuron_configurations_rule_set(self): "instance_type": "ml.g5.2xlarge", "initial_instance_count": 1, "endpoint_name": "test", + "update_endpoint": False, + }, + }, + # Async deployment with update + { + "input_args": { + "endpoint_name": "existing-endpoint", + "inference_config": AsyncInferenceConfig(output_path="op-path"), + "update_endpoint": True, + }, + "call_params": { + "async_inference_config": AsyncInferenceConfig(output_path="op-path"), + "instance_type": "ml.g5.2xlarge", + "initial_instance_count": 1, + "endpoint_name": "existing-endpoint", + "update_endpoint": True, }, }, + # Multi-Model deployment (update_endpoint not supported) { - "input_args": {"endpoint_name": "test", "inference_config": RESOURCE_REQUIREMENTS}, + "input_args": { + "endpoint_name": "test", + "inference_config": RESOURCE_REQUIREMENTS, + }, "call_params": { "resources": RESOURCE_REQUIREMENTS, "role": "role-arn", @@ -4080,12 +4131,16 @@ def test_neuron_configurations_rule_set(self): "instance_type": "ml.g5.2xlarge", "mode": Mode.SAGEMAKER_ENDPOINT, "endpoint_type": EndpointType.INFERENCE_COMPONENT_BASED, + "update_endpoint": False, }, }, + # Batch transform { "input_args": { "inference_config": BatchTransformInferenceConfig( - instance_count=1, instance_type="ml.m5.large", output_path="op-path" + instance_count=1, + instance_type="ml.m5.large", + output_path="op-path" ) }, "call_params": { @@ -4096,7 +4151,16 @@ def test_neuron_configurations_rule_set(self): "id": "Batch", }, ], - ids=["Real Time", "Serverless", "Async", "Multi-Model", "Batch"], + ids=[ + 
"Real Time", + "Real Time Update", + "Serverless", + "Serverless Update", + "Async", + "Async Update", + "Multi-Model", + "Batch", + ], ) @patch("sagemaker.serve.builder.model_builder.unique_name_from_base") def test_deploy(mock_unique_name_from_base, test_case): @@ -4119,3 +4183,20 @@ def test_deploy(mock_unique_name_from_base, test_case): diff = deepdiff.DeepDiff(kwargs, test_case["call_params"]) assert diff == {} + + +def test_deploy_multi_model_update_error(): + model_builder = ModelBuilder( + model="meta-llama/Meta-Llama-3-8B-Instruct", + env_vars={"HUGGING_FACE_HUB_TOKEN": "token"}, + role_arn="role-arn", + instance_type="ml.g5.2xlarge", + ) + setattr(model_builder, "built_model", MagicMock()) + + with pytest.raises(ValueError, match="Currently update_endpoint is supported for single model endpoints"): + model_builder.deploy( + endpoint_name="test", + inference_config=RESOURCE_REQUIREMENTS, + update_endpoint=True + ) From 0bf6404dc6678a622f5c04b9babc51b2dd41a3c1 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Wed, 12 Mar 2025 10:13:04 -0700 Subject: [PATCH 08/11] fix: fix unit test, black-check, pylint errors --- src/sagemaker/huggingface/model.py | 4 ++-- src/sagemaker/model.py | 4 +++- src/sagemaker/serve/builder/model_builder.py | 4 +++- src/sagemaker/session.py | 4 ++-- tests/unit/sagemaker/jumpstart/model/test_model.py | 2 +- tests/unit/sagemaker/model/test_deploy.py | 9 +++++---- .../sagemaker/serve/builder/test_model_builder.py | 12 +++++------- 7 files changed, 21 insertions(+), 18 deletions(-) diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index f4b44fc057..6ef28f99e5 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -298,8 +298,8 @@ def deploy( explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability configuration for use with Amazon SageMaker Clarify. (default: None) update_endpoint (Optional[bool]): Flag to update the model in an existing Amazon SageMaker endpoint. - If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources - corresponding to the previous EndpointConfig. Default: False + If True, this will deploy a new EndpointConfig to an already existing endpoint and + delete resources corresponding to the previous EndpointConfig. 
Default: False
         Raises:
             ValueError: If arguments combination check failed in these circumstances:
                 - If no role is specified or
diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py
index 24be862ecc..74177795a7 100644
--- a/src/sagemaker/model.py
+++ b/src/sagemaker/model.py
@@ -1631,7 +1631,9 @@ def deploy(
         # Support multiple models on same endpoint
         if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
             if update_endpoint:
-                raise ValueError("Currently update_endpoint is supported for single model endpoints")
+                raise ValueError(
+                    "Currently update_endpoint is supported for single model endpoints"
+                )
             if endpoint_name:
                 self.endpoint_name = endpoint_name
             else:
diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py
index 1a4746a6db..4a49c0538e 100644
--- a/src/sagemaker/serve/builder/model_builder.py
+++ b/src/sagemaker/serve/builder/model_builder.py
@@ -1663,7 +1663,9 @@ def deploy(
 
         if isinstance(inference_config, ResourceRequirements):
             if update_endpoint:
-                raise ValueError("Currently update_endpoint is supported for single model endpoints")
+                raise ValueError(
+                    "Currently update_endpoint is supported for single model endpoints"
+                )
             # Multi Model and MultiContainer endpoints with Inference Component
             return self.built_model.deploy(
                 instance_type=self.instance_type,
diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index 4a65b3ccb5..be0bb9b688 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -4540,8 +4540,8 @@ def create_endpoint_config(
                 empty object passed through, will use pre-defined values in
                 ``ServerlessInferenceConfig`` class to deploy serverless endpoint. Deploy an
                 instance based endpoint if it's None. (default: None).
-            routing_config (Optional[Dict[str, Any]): Settings the control how the endpoint routes incoming
-                traffic to the instances that the endpoint hosts.
+            routing_config (Optional[Dict[str, Any]]): Settings that control how the endpoint routes
+                incoming traffic to the instances that the endpoint hosts.
                 Currently, support dictionary key ``RoutingStrategy``.
 
                 ..
code:: python diff --git a/tests/unit/sagemaker/jumpstart/model/test_model.py b/tests/unit/sagemaker/jumpstart/model/test_model.py index be961828f4..d9b126f651 100644 --- a/tests/unit/sagemaker/jumpstart/model/test_model.py +++ b/tests/unit/sagemaker/jumpstart/model/test_model.py @@ -794,7 +794,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self): and reach out to JumpStart team.""" init_args_to_skip: Set[str] = set(["model_reference_arn"]) - deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"]) + deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn", "update_endpoint"]) deploy_args_removed_at_deploy_time: Set[str] = set(["model_access_configs"]) parent_class_init = Model.__init__ diff --git a/tests/unit/sagemaker/model/test_deploy.py b/tests/unit/sagemaker/model/test_deploy.py index 49a67871e7..4167ca62c3 100644 --- a/tests/unit/sagemaker/model/test_deploy.py +++ b/tests/unit/sagemaker/model/test_deploy.py @@ -1134,13 +1134,12 @@ def test_deploy_with_update_endpoint(production_variant, name_from_base, sagemak # Test update_endpoint with async inference config async_inference_config = AsyncInferenceConfig( - output_path="s3://bucket/output", - failure_path="s3://bucket/failure" + output_path="s3://bucket/output", failure_path="s3://bucket/failure" ) async_inference_config_dict = { "OutputConfig": { "S3OutputPath": "s3://bucket/output", - "S3FailurePath": "s3://bucket/failure" + "S3FailurePath": "s3://bucket/failure", }, } model.deploy( @@ -1182,7 +1181,9 @@ def test_deploy_with_update_endpoint_inference_component(production_variant, sag ) # Test that updating endpoint with inference component raises error - with pytest.raises(ValueError, match="Currently update_endpoint is supported for single model endpoints"): + with pytest.raises( + ValueError, match="Currently update_endpoint is supported for single model endpoints" + ): model.deploy( endpoint_name="test-endpoint", instance_type=INSTANCE_TYPE, diff --git a/tests/unit/sagemaker/serve/builder/test_model_builder.py b/tests/unit/sagemaker/serve/builder/test_model_builder.py index ee37bf7b43..6661c6e2bf 100644 --- a/tests/unit/sagemaker/serve/builder/test_model_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_model_builder.py @@ -4138,9 +4138,7 @@ def test_neuron_configurations_rule_set(self): { "input_args": { "inference_config": BatchTransformInferenceConfig( - instance_count=1, - instance_type="ml.m5.large", - output_path="op-path" + instance_count=1, instance_type="ml.m5.large", output_path="op-path" ) }, "call_params": { @@ -4194,9 +4192,9 @@ def test_deploy_multi_model_update_error(): ) setattr(model_builder, "built_model", MagicMock()) - with pytest.raises(ValueError, match="Currently update_endpoint is supported for single model endpoints"): + with pytest.raises( + ValueError, match="Currently update_endpoint is supported for single model endpoints" + ): model_builder.deploy( - endpoint_name="test", - inference_config=RESOURCE_REQUIREMENTS, - update_endpoint=True + endpoint_name="test", inference_config=RESOURCE_REQUIREMENTS, update_endpoint=True ) From c67d7df4f9a1cc15aff8a2cf9dad277f05355dc3 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Wed, 12 Mar 2025 10:51:19 -0700 Subject: [PATCH 09/11] fix: fix black-check, pylint errors --- src/sagemaker/huggingface/model.py | 8 +++++--- src/sagemaker/model.py | 7 ++++--- src/sagemaker/serve/builder/model_builder.py | 7 ++++--- src/sagemaker/session.py | 5 ++++- 4 files changed, 17 insertions(+), 10 deletions(-) diff --git 
a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index 6ef28f99e5..3ca25fb3ce 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -297,9 +297,11 @@ def deploy( would like to deploy the model and endpoint with recommended parameters. explainer_config (sagemaker.explainer.ExplainerConfig): Specifies online explainability configuration for use with Amazon SageMaker Clarify. (default: None) - update_endpoint (Optional[bool]): Flag to update the model in an existing Amazon SageMaker endpoint. - If True, this will deploy a new EndpointConfig to an already existing endpoint and - delete resources corresponding to the previous EndpointConfig. Default: False + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig. Default: False + Note: Currently this is supported for single model endpoints Raises: ValueError: If arguments combination check failed in these circumstances: - If no role is specified or diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index 74177795a7..b281d9f489 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -1497,9 +1497,10 @@ def deploy( inference_ami_version (Optional [str]): Specifies an option from a collection of preconfigured Amazon Machine Image (AMI) images. For a full list of options, see: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html - update_endpoint (Optional[bool]): Flag to update the model in an existing Amazon SageMaker endpoint. - If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources - corresponding to the previous EndpointConfig. Default: False + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig. Default: False Note: Currently this is supported for single model endpoints Raises: ValueError: If arguments combination check failed in these circumstances: diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 4a49c0538e..9122f22e44 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1616,9 +1616,10 @@ def deploy( AsyncInferenceConfig, BatchTransformInferenceConfig, ResourceRequirements]]) : Additional Config for different deployment types such as serverless, async, batch and multi-model/container - update_endpoint (Optional[bool]): Flag to update the model in an existing Amazon SageMaker endpoint. - If True, this will deploy a new EndpointConfig to an already existing endpoint and delete resources - corresponding to the previous EndpointConfig. Default: False + update_endpoint (Optional[bool]): + Flag to update the model in an existing Amazon SageMaker endpoint. + If True, this will deploy a new EndpointConfig to an already existing endpoint + and delete resources corresponding to the previous EndpointConfig. 
Default: False
                 Note: Currently this is supported for single model endpoints
         Returns:
             Transformer for Batch Deployments
             Predictors for all others
diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index be0bb9b688..38fa7f8c26 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -4549,7 +4549,10 @@ def create_endpoint_config(
                 {
                     "RoutingStrategy": sagemaker.enums.RoutingStrategy.RANDOM
                 }
-            inference_ami_version (Optional [str]): Specifies an option from a collection of preconfigured
+            inference_ami_version (Optional [str]):
+                Specifies an option from a collection of preconfigured
+                Amazon Machine Image (AMI) images. For a full list of options, see:
+                https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_ProductionVariant.html
 
         Example:
             >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]

From 89e18a9f3e75843e3074bf2f07221f9ec67ef5ea Mon Sep 17 00:00:00 2001
From: Roja Reddy Sareddy
Date: Mon, 7 Apr 2025 23:06:31 -0700
Subject: [PATCH 10/11] fix: Added handler for pipeline variable while creating
 process job

---
 src/sagemaker/processing.py   |  43 +++++-
 tests/unit/test_processing.py | 255 +++++++++++++++++++++++++++++++++-
 2 files changed, 296 insertions(+), 2 deletions(-)

diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py
index d8674f269d..1fdfaae3c2 100644
--- a/src/sagemaker/processing.py
+++ b/src/sagemaker/processing.py
@@ -60,9 +60,10 @@
 )
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
-from sagemaker.workflow.execution_variables import ExecutionVariables
+from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables
 from sagemaker.workflow.functions import Join
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.parameters import Parameter
 
 logger = logging.getLogger(__name__)
 
@@ -314,6 +315,15 @@ def _normalize_args(
                 "code argument has to be a valid S3 URI or local file path "
                 + "rather than a pipeline variable"
             )
+        if arguments is not None:
+            normalized_arguments = []
+            for arg in arguments:
+                if isinstance(arg, PipelineVariable):
+                    normalized_value = self._normalize_pipeline_variable(arg)
+                    normalized_arguments.append(normalized_value)
+                else:
+                    normalized_arguments.append(str(arg))
+            arguments = normalized_arguments
 
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
@@ -499,6 +509,37 @@ def _normalize_outputs(self, outputs=None):
             normalized_outputs.append(output)
         return normalized_outputs
 
+    def _normalize_pipeline_variable(self, value):
+        """Helper function to normalize PipelineVariable objects"""
+        try:
+            if isinstance(value, Parameter):
+                return str(value.default_value) if value.default_value is not None else None
+
+            elif isinstance(value, ExecutionVariable):
+                return f"{value.name}"
+
+            elif isinstance(value, Join):
+                normalized_values = [
+                    normalize_pipeline_variable(v) if isinstance(v, PipelineVariable) else str(v)
+                    for v in value.values
+                ]
+                return value.on.join(normalized_values)
+
+            elif isinstance(value, PipelineVariable):
+                if hasattr(value, 'default_value'):
+                    return str(value.default_value)
+                elif hasattr(value, 'expr'):
+                    return str(value.expr)
+
+            return str(value)
+
+        except AttributeError as e:
+            raise ValueError(f"Missing required attribute while normalizing {type(value).__name__}: {e}")
+        except TypeError as e:
+            raise ValueError(f"Type error while normalizing {type(value).__name__}: {e}")
+        except Exception as e:
+            raise ValueError(f"Error normalizing {type(value).__name__}: 
{e}") + class ScriptProcessor(Processor): """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework.""" diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 06d2cde02e..0088e10640 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -46,8 +46,9 @@ from sagemaker.fw_utils import UploadedCode from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig from sagemaker.workflow.functions import Join -from sagemaker.workflow.execution_variables import ExecutionVariables +from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables from tests.unit import SAGEMAKER_CONFIG_PROCESSING_JOB +from sagemaker.workflow.parameters import ParameterString, Parameter BUCKET_NAME = "mybucket" REGION = "us-west-2" @@ -1717,3 +1718,255 @@ def _get_describe_response_inputs_and_ouputs(): "ProcessingInputs": _get_expected_args_all_parameters(None)["inputs"], "ProcessingOutputConfig": _get_expected_args_all_parameters(None)["output_config"], } + +# Parameters +def _get_data_inputs_with_parameters(): + return [ + ProcessingInput( + source=ParameterString( + name="input_data", + default_value="s3://dummy-bucket/input" + ), + destination="/opt/ml/processing/input", + input_name="input-1" + ) + ] + + +def _get_data_outputs_with_parameters(): + return [ + ProcessingOutput( + source="/opt/ml/processing/output", + destination=ParameterString( + name="output_data", + default_value="s3://dummy-bucket/output" + ), + output_name="output-1" + ) + ] + + +def _get_expected_args_with_parameters(job_name): + return { + "inputs": [{ + "InputName": "input-1", + "S3Input": { + "S3Uri": "s3://dummy-bucket/input", + "LocalPath": "/opt/ml/processing/input", + "S3DataType": "S3Prefix", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + "S3CompressionType": "None" + } + }], + "output_config": { + "Outputs": [{ + "OutputName": "output-1", + "S3Output": { + "S3Uri": "s3://dummy-bucket/output", + "LocalPath": "/opt/ml/processing/output", + "S3UploadMode": "EndOfJob" + } + }] + }, + "job_name": job_name, + "resources": { + "ClusterConfig": { + "InstanceType": "ml.m4.xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 100, + "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key" + } + }, + "stopping_condition": {"MaxRuntimeInSeconds": 3600}, + "app_specification": { + "ImageUri": "custom-image-uri", + "ContainerArguments": [ + "--input-data", + "s3://dummy-bucket/input-param", + "--output-path", + "s3://dummy-bucket/output-param" + ], + "ContainerEntrypoint": ["python3"] + }, + "environment": {"my_env_variable": "my_env_variable_value"}, + "network_config": { + "EnableNetworkIsolation": True, + "EnableInterContainerTrafficEncryption": True, + "VpcConfig": { + "Subnets": ["my_subnet_id"], + "SecurityGroupIds": ["my_security_group_id"] + } + }, + "role_arn": "dummy/role", + "tags": [{"Key": "my-tag", "Value": "my-tag-value"}], + "experiment_config": {"ExperimentName": "AnExperiment"} + } + + +@patch("os.path.exists", return_value=True) +@patch("os.path.isfile", return_value=True) +@patch("sagemaker.utils.repack_model") +@patch("sagemaker.utils.create_tar_file") +@patch("sagemaker.session.Session.upload_data") +def test_script_processor_with_parameter_string( + upload_data_mock, + create_tar_file_mock, + repack_model_mock, + exists_mock, + isfile_mock, + sagemaker_session, +): + """Test ScriptProcessor with ParameterString arguments""" + upload_data_mock.return_value = 
"s3://mocked_s3_uri_from_upload_data" + + # Setup processor + processor = ScriptProcessor( + role="arn:aws:iam::012345678901:role/SageMakerRole", # Updated role ARN + image_uri="custom-image-uri", + command=["python3"], + instance_type="ml.m4.xlarge", + instance_count=1, + volume_size_in_gb=100, + volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key", + output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key", + max_runtime_in_seconds=3600, + base_job_name="test_processor", + env={"my_env_variable": "my_env_variable_value"}, + tags=[{"Key": "my-tag", "Value": "my-tag-value"}], + network_config=NetworkConfig( + subnets=["my_subnet_id"], + security_group_ids=["my_security_group_id"], + enable_network_isolation=True, + encrypt_inter_container_traffic=True, + ), + sagemaker_session=sagemaker_session, + ) + + input_param = ParameterString( + name="input_param", + default_value="s3://dummy-bucket/input-param" + ) + output_param = ParameterString( + name="output_param", + default_value="s3://dummy-bucket/output-param" + ) + exec_var = ExecutionVariable( + name="ExecutionTest" + ) + join_var = Join( + on="/", + values=["s3://bucket", "prefix", "file.txt"] + ) + dummy_str_var = "test-variable" + + # Define expected arguments + expected_args = { + "inputs": [ + { + "InputName": "input-1", + "AppManaged": False, + "S3Input": { + "S3Uri": ParameterString( + name="input_data", + default_value="s3://dummy-bucket/input" + ), + "LocalPath": "/opt/ml/processing/input", + "S3DataType": "S3Prefix", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + "S3CompressionType": "None" + } + }, + { + "InputName": "code", + "AppManaged": False, + "S3Input": { + "S3Uri": "s3://mocked_s3_uri_from_upload_data", + "LocalPath": "/opt/ml/processing/input/code", + "S3DataType": "S3Prefix", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + "S3CompressionType": "None" + } + } + ], + "output_config": { + "Outputs": [ + { + "OutputName": "output-1", + "AppManaged": False, + "S3Output": { + "S3Uri": ParameterString( + name="output_data", + default_value="s3://dummy-bucket/output" + ), + "LocalPath": "/opt/ml/processing/output", + "S3UploadMode": "EndOfJob" + } + } + ], + "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key" + }, + "job_name": "test_job", + "resources": { + "ClusterConfig": { + "InstanceType": "ml.m4.xlarge", + "InstanceCount": 1, + "VolumeSizeInGB": 100, + "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key" + } + }, + "stopping_condition": {"MaxRuntimeInSeconds": 3600}, + "app_specification": { + "ImageUri": "custom-image-uri", + "ContainerArguments": [ + "--input-data", + "s3://dummy-bucket/input-param", + "--output-path", + "s3://dummy-bucket/output-param", + "--exec-arg", "ExecutionTest", + "--join-arg", "s3://bucket/prefix/file.txt", + "--string-param", "test-variable" + ], + "ContainerEntrypoint": ["python3", "/opt/ml/processing/input/code/processing_code.py"] + }, + "environment": {"my_env_variable": "my_env_variable_value"}, + "network_config": { + "EnableNetworkIsolation": True, + "EnableInterContainerTrafficEncryption": True, + "VpcConfig": { + "SecurityGroupIds": ["my_security_group_id"], + "Subnets": ["my_subnet_id"] + } + }, + "role_arn": "arn:aws:iam::012345678901:role/SageMakerRole", + "tags": [{"Key": "my-tag", "Value": "my-tag-value"}], + "experiment_config": {"ExperimentName": "AnExperiment"} + } + + # Run processor + processor.run( + code="/local/path/to/processing_code.py", + 
inputs=_get_data_inputs_with_parameters(), + outputs=_get_data_outputs_with_parameters(), + arguments=[ + "--input-data", + input_param, + "--output-path", + output_param, + "--exec-arg", exec_var, + "--join-arg", join_var, + "--string-param", dummy_str_var + ], + wait=True, + logs=False, + job_name="test_job", + experiment_config={"ExperimentName": "AnExperiment"}, + ) + + # Assert + sagemaker_session.process.assert_called_with(**expected_args) + assert "test_job" in processor._current_job_name + + From 7f15e1928037bf34b60a8697ac5f1450c73abae1 Mon Sep 17 00:00:00 2001 From: Roja Reddy Sareddy Date: Wed, 9 Apr 2025 15:18:05 -0700 Subject: [PATCH 11/11] fix: Added handler for pipeline variable while creating process job --- src/sagemaker/processing.py | 46 +----- .../workflow/test_processing_step.py | 17 +- tests/unit/test_processing.py | 150 +++++++++--------- 3 files changed, 93 insertions(+), 120 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 1fdfaae3c2..7beef2e5bd 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -17,7 +17,7 @@ and interpretation on Amazon SageMaker. """ from __future__ import absolute_import - +import json import logging import os import pathlib @@ -60,10 +60,9 @@ ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable -from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables +from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join from sagemaker.workflow.pipeline_context import runnable_by_pipeline -from sagemaker.workflow.parameters import Parameter logger = logging.getLogger(__name__) @@ -316,14 +315,14 @@ def _normalize_args( + "rather than a pipeline variable" ) if arguments is not None: - normalized_arguments = [] + processed_arguments = [] for arg in arguments: if isinstance(arg, PipelineVariable): - normalized_value = self._normalize_pipeline_variable(arg) - normalized_arguments.append(normalized_value) + processed_value = json.dumps(arg.expr) + processed_arguments.append(processed_value) else: - normalized_arguments.append(str(arg)) - arguments = normalized_arguments + processed_arguments.append(str(arg)) + arguments = processed_arguments self._current_job_name = self._generate_current_job_name(job_name=job_name) @@ -509,37 +508,6 @@ def _normalize_outputs(self, outputs=None): normalized_outputs.append(output) return normalized_outputs - def _normalize_pipeline_variable(self, value): - """Helper function to normalize PipelineVariable objects""" - try: - if isinstance(value, Parameter): - return str(value.default_value) if value.default_value is not None else None - - elif isinstance(value, ExecutionVariable): - return f"{value.name}" - - elif isinstance(value, Join): - normalized_values = [ - normalize_pipeline_variable(v) if isinstance(v, PipelineVariable) else str(v) - for v in value.values - ] - return value.on.join(normalized_values) - - elif isinstance(value, PipelineVariable): - if hasattr(value, 'default_value'): - return str(value.default_value) - elif hasattr(value, 'expr'): - return str(value.expr) - - return str(value) - - except AttributeError as e: - raise ValueError(f"Missing required attribute while normalizing {type(value).__name__}: {e}") - except TypeError as e: - raise ValueError(f"Type error while normalizing {type(value).__name__}: {e}") - except Exception as e: - raise ValueError(f"Error normalizing {type(value).__name__}: 
{e}") - class ScriptProcessor(Processor): """Handles Amazon SageMaker processing tasks for jobs using a machine learning framework.""" diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py index 0dcd7c2495..f94e0791cb 100644 --- a/tests/unit/sagemaker/workflow/test_processing_step.py +++ b/tests/unit/sagemaker/workflow/test_processing_step.py @@ -824,7 +824,12 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session): processor, run_inputs = spark_processor processor.sagemaker_session = pipeline_session processor.role = ROLE - + arguments_output = [ + "--input", + "input-data-uri", + "--output", + '{"Get": "Parameters.MyArgOutput"}', + ] run_inputs["inputs"] = processing_input step_args = processor.run(**run_inputs) @@ -835,7 +840,7 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session): step_args = get_step_args_helper(step_args, "Processing") - assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"] + assert step_args["AppSpecification"]["ContainerArguments"] == arguments_output entry_points = step_args["AppSpecification"]["ContainerEntrypoint"] entry_points_expr = [] @@ -1019,6 +1024,12 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_ processor, run_inputs = spark_processor processor.sagemaker_session = pipeline_session processor.role = ROLE + arguments_output = [ + "--input", + "input-data-uri", + "--output", + '{"Get": "Parameters.MyArgOutput"}', + ] run_inputs["inputs"] = processing_input @@ -1030,7 +1041,7 @@ def test_spark_processor_local_code(spark_processor, processing_input, pipeline_ step_args = get_step_args_helper(step_args, "Processing") - assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"] + assert step_args["AppSpecification"]["ContainerArguments"] == arguments_output entry_points = step_args["AppSpecification"]["ContainerEntrypoint"] entry_points_expr = [] diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 0088e10640..7b020c61bf 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -48,7 +48,7 @@ from sagemaker.workflow.functions import Join from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables from tests.unit import SAGEMAKER_CONFIG_PROCESSING_JOB -from sagemaker.workflow.parameters import ParameterString, Parameter +from sagemaker.workflow.parameters import ParameterString BUCKET_NAME = "mybucket" REGION = "us-west-2" @@ -1719,16 +1719,14 @@ def _get_describe_response_inputs_and_ouputs(): "ProcessingOutputConfig": _get_expected_args_all_parameters(None)["output_config"], } + # Parameters def _get_data_inputs_with_parameters(): return [ ProcessingInput( - source=ParameterString( - name="input_data", - default_value="s3://dummy-bucket/input" - ), + source=ParameterString(name="input_data", default_value="s3://dummy-bucket/input"), destination="/opt/ml/processing/input", - input_name="input-1" + input_name="input-1", ) ] @@ -1738,36 +1736,39 @@ def _get_data_outputs_with_parameters(): ProcessingOutput( source="/opt/ml/processing/output", destination=ParameterString( - name="output_data", - default_value="s3://dummy-bucket/output" + name="output_data", default_value="s3://dummy-bucket/output" ), - output_name="output-1" + output_name="output-1", ) ] def _get_expected_args_with_parameters(job_name): return { - "inputs": [{ - "InputName": "input-1", - "S3Input": { - 
"S3Uri": "s3://dummy-bucket/input", - "LocalPath": "/opt/ml/processing/input", - "S3DataType": "S3Prefix", - "S3InputMode": "File", - "S3DataDistributionType": "FullyReplicated", - "S3CompressionType": "None" + "inputs": [ + { + "InputName": "input-1", + "S3Input": { + "S3Uri": "s3://dummy-bucket/input", + "LocalPath": "/opt/ml/processing/input", + "S3DataType": "S3Prefix", + "S3InputMode": "File", + "S3DataDistributionType": "FullyReplicated", + "S3CompressionType": "None", + }, } - }], + ], "output_config": { - "Outputs": [{ - "OutputName": "output-1", - "S3Output": { - "S3Uri": "s3://dummy-bucket/output", - "LocalPath": "/opt/ml/processing/output", - "S3UploadMode": "EndOfJob" + "Outputs": [ + { + "OutputName": "output-1", + "S3Output": { + "S3Uri": "s3://dummy-bucket/output", + "LocalPath": "/opt/ml/processing/output", + "S3UploadMode": "EndOfJob", + }, } - }] + ] }, "job_name": job_name, "resources": { @@ -1775,7 +1776,7 @@ def _get_expected_args_with_parameters(job_name): "InstanceType": "ml.m4.xlarge", "InstanceCount": 1, "VolumeSizeInGB": 100, - "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key" + "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key", } }, "stopping_condition": {"MaxRuntimeInSeconds": 3600}, @@ -1785,9 +1786,9 @@ def _get_expected_args_with_parameters(job_name): "--input-data", "s3://dummy-bucket/input-param", "--output-path", - "s3://dummy-bucket/output-param" + "s3://dummy-bucket/output-param", ], - "ContainerEntrypoint": ["python3"] + "ContainerEntrypoint": ["python3"], }, "environment": {"my_env_variable": "my_env_variable_value"}, "network_config": { @@ -1795,12 +1796,12 @@ def _get_expected_args_with_parameters(job_name): "EnableInterContainerTrafficEncryption": True, "VpcConfig": { "Subnets": ["my_subnet_id"], - "SecurityGroupIds": ["my_security_group_id"] - } + "SecurityGroupIds": ["my_security_group_id"], + }, }, "role_arn": "dummy/role", "tags": [{"Key": "my-tag", "Value": "my-tag-value"}], - "experiment_config": {"ExperimentName": "AnExperiment"} + "experiment_config": {"ExperimentName": "AnExperiment"}, } @@ -1810,12 +1811,12 @@ def _get_expected_args_with_parameters(job_name): @patch("sagemaker.utils.create_tar_file") @patch("sagemaker.session.Session.upload_data") def test_script_processor_with_parameter_string( - upload_data_mock, - create_tar_file_mock, - repack_model_mock, - exists_mock, - isfile_mock, - sagemaker_session, + upload_data_mock, + create_tar_file_mock, + repack_model_mock, + exists_mock, + isfile_mock, + sagemaker_session, ): """Test ScriptProcessor with ParameterString arguments""" upload_data_mock.return_value = "s3://mocked_s3_uri_from_upload_data" @@ -1843,21 +1844,12 @@ def test_script_processor_with_parameter_string( sagemaker_session=sagemaker_session, ) - input_param = ParameterString( - name="input_param", - default_value="s3://dummy-bucket/input-param" - ) + input_param = ParameterString(name="input_param", default_value="s3://dummy-bucket/input-param") output_param = ParameterString( - name="output_param", - default_value="s3://dummy-bucket/output-param" - ) - exec_var = ExecutionVariable( - name="ExecutionTest" - ) - join_var = Join( - on="/", - values=["s3://bucket", "prefix", "file.txt"] + name="output_param", default_value="s3://dummy-bucket/output-param" ) + exec_var = ExecutionVariable(name="ExecutionTest") + join_var = Join(on="/", values=["s3://bucket", "prefix", "file.txt"]) dummy_str_var = "test-variable" # Define expected arguments @@ -1868,15 +1860,14 @@ def 
diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py
index 0088e10640..7b020c61bf 100644
--- a/tests/unit/test_processing.py
+++ b/tests/unit/test_processing.py
@@ -48,7 +48,7 @@
 from sagemaker.workflow.functions import Join
 from sagemaker.workflow.execution_variables import ExecutionVariable, ExecutionVariables
 from tests.unit import SAGEMAKER_CONFIG_PROCESSING_JOB
-from sagemaker.workflow.parameters import ParameterString, Parameter
+from sagemaker.workflow.parameters import ParameterString

 BUCKET_NAME = "mybucket"
 REGION = "us-west-2"
@@ -1719,16 +1719,14 @@ def _get_describe_response_inputs_and_ouputs():
         "ProcessingOutputConfig": _get_expected_args_all_parameters(None)["output_config"],
     }

+
 # Parameters
 def _get_data_inputs_with_parameters():
     return [
         ProcessingInput(
-            source=ParameterString(
-                name="input_data",
-                default_value="s3://dummy-bucket/input"
-            ),
+            source=ParameterString(name="input_data", default_value="s3://dummy-bucket/input"),
             destination="/opt/ml/processing/input",
-            input_name="input-1"
+            input_name="input-1",
         )
     ]
@@ -1738,36 +1736,39 @@ def _get_data_outputs_with_parameters():
         ProcessingOutput(
             source="/opt/ml/processing/output",
             destination=ParameterString(
-                name="output_data",
-                default_value="s3://dummy-bucket/output"
+                name="output_data", default_value="s3://dummy-bucket/output"
             ),
-            output_name="output-1"
+            output_name="output-1",
         )
     ]


 def _get_expected_args_with_parameters(job_name):
     return {
-        "inputs": [{
-            "InputName": "input-1",
-            "S3Input": {
-                "S3Uri": "s3://dummy-bucket/input",
-                "LocalPath": "/opt/ml/processing/input",
-                "S3DataType": "S3Prefix",
-                "S3InputMode": "File",
-                "S3DataDistributionType": "FullyReplicated",
-                "S3CompressionType": "None"
-            }
-        }],
+        "inputs": [
+            {
+                "InputName": "input-1",
+                "S3Input": {
+                    "S3Uri": "s3://dummy-bucket/input",
+                    "LocalPath": "/opt/ml/processing/input",
+                    "S3DataType": "S3Prefix",
+                    "S3InputMode": "File",
+                    "S3DataDistributionType": "FullyReplicated",
+                    "S3CompressionType": "None",
+                },
+            }
+        ],
         "output_config": {
-            "Outputs": [{
-                "OutputName": "output-1",
-                "S3Output": {
-                    "S3Uri": "s3://dummy-bucket/output",
-                    "LocalPath": "/opt/ml/processing/output",
-                    "S3UploadMode": "EndOfJob"
-                }
-            }]
+            "Outputs": [
+                {
+                    "OutputName": "output-1",
+                    "S3Output": {
+                        "S3Uri": "s3://dummy-bucket/output",
+                        "LocalPath": "/opt/ml/processing/output",
+                        "S3UploadMode": "EndOfJob",
+                    },
+                }
+            ]
         },
         "job_name": job_name,
         "resources": {
@@ -1775,7 +1776,7 @@
             "InstanceType": "ml.m4.xlarge",
             "InstanceCount": 1,
             "VolumeSizeInGB": 100,
-            "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key"
+            "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
         }
     },
     "stopping_condition": {"MaxRuntimeInSeconds": 3600},
@@ -1785,9 +1786,9 @@
             "--input-data",
             "s3://dummy-bucket/input-param",
             "--output-path",
-            "s3://dummy-bucket/output-param"
+            "s3://dummy-bucket/output-param",
         ],
-        "ContainerEntrypoint": ["python3"]
+        "ContainerEntrypoint": ["python3"],
     },
     "environment": {"my_env_variable": "my_env_variable_value"},
     "network_config": {
@@ -1795,12 +1796,12 @@
         "EnableInterContainerTrafficEncryption": True,
         "VpcConfig": {
             "Subnets": ["my_subnet_id"],
-            "SecurityGroupIds": ["my_security_group_id"]
-        }
+            "SecurityGroupIds": ["my_security_group_id"],
+        },
     },
     "role_arn": "dummy/role",
     "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
-    "experiment_config": {"ExperimentName": "AnExperiment"}
+    "experiment_config": {"ExperimentName": "AnExperiment"},
 }
@@ -1810,12 +1811,12 @@
 @patch("sagemaker.utils.create_tar_file")
 @patch("sagemaker.session.Session.upload_data")
 def test_script_processor_with_parameter_string(
-        upload_data_mock,
-        create_tar_file_mock,
-        repack_model_mock,
-        exists_mock,
-        isfile_mock,
-        sagemaker_session,
+    upload_data_mock,
+    create_tar_file_mock,
+    repack_model_mock,
+    exists_mock,
+    isfile_mock,
+    sagemaker_session,
 ):
     """Test ScriptProcessor with ParameterString arguments"""
     upload_data_mock.return_value = "s3://mocked_s3_uri_from_upload_data"
@@ -1843,21 +1844,12 @@ def test_script_processor_with_parameter_string(
         sagemaker_session=sagemaker_session,
     )

-    input_param = ParameterString(
-        name="input_param",
-        default_value="s3://dummy-bucket/input-param"
-    )
+    input_param = ParameterString(name="input_param", default_value="s3://dummy-bucket/input-param")
     output_param = ParameterString(
-        name="output_param",
-        default_value="s3://dummy-bucket/output-param"
-    )
-    exec_var = ExecutionVariable(
-        name="ExecutionTest"
-    )
-    join_var = Join(
-        on="/",
-        values=["s3://bucket", "prefix", "file.txt"]
+        name="output_param", default_value="s3://dummy-bucket/output-param"
     )
+    exec_var = ExecutionVariable(name="ExecutionTest")
+    join_var = Join(on="/", values=["s3://bucket", "prefix", "file.txt"])
     dummy_str_var = "test-variable"

     # Define expected arguments
@@ -1868,15 +1860,14 @@ def test_script_processor_with_parameter_string(
                 "AppManaged": False,
                 "S3Input": {
                     "S3Uri": ParameterString(
-                        name="input_data",
-                        default_value="s3://dummy-bucket/input"
+                        name="input_data", default_value="s3://dummy-bucket/input"
                     ),
                     "LocalPath": "/opt/ml/processing/input",
                     "S3DataType": "S3Prefix",
                     "S3InputMode": "File",
                     "S3DataDistributionType": "FullyReplicated",
-                    "S3CompressionType": "None"
-                }
+                    "S3CompressionType": "None",
+                },
             },
             {
                 "InputName": "code",
@@ -1887,9 +1878,9 @@ def test_script_processor_with_parameter_string(
                 "AppManaged": False,
                 "S3Input": {
                     "S3Uri": "s3://mocked_s3_uri_from_upload_data",
                     "LocalPath": "/opt/ml/processing/input/code",
                     "S3DataType": "S3Prefix",
                     "S3InputMode": "File",
                     "S3DataDistributionType": "FullyReplicated",
-                    "S3CompressionType": "None"
-                }
-            }
+                    "S3CompressionType": "None",
+                },
+            },
         ],
         "output_config": {
             "Outputs": [
@@ -1898,15 +1889,14 @@ def test_script_processor_with_parameter_string(
                 {
                     "OutputName": "output-1",
                     "AppManaged": False,
                     "S3Output": {
                         "S3Uri": ParameterString(
-                            name="output_data",
-                            default_value="s3://dummy-bucket/output"
+                            name="output_data", default_value="s3://dummy-bucket/output"
                         ),
                         "LocalPath": "/opt/ml/processing/output",
-                        "S3UploadMode": "EndOfJob"
-                    }
+                        "S3UploadMode": "EndOfJob",
+                    },
                 }
             ],
-            "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key"
+            "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
         },
         "job_name": "test_job",
@@ -1914,7 +1904,7 @@ def test_script_processor_with_parameter_string(
                 "InstanceType": "ml.m4.xlarge",
                 "InstanceCount": 1,
                 "VolumeSizeInGB": 100,
-                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key"
+                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
             }
         },
         "stopping_condition": {"MaxRuntimeInSeconds": 3600},
@@ -1922,14 +1912,17 @@ def test_script_processor_with_parameter_string(
             "ImageUri": "custom-image-uri",
             "ContainerArguments": [
                 "--input-data",
-                "s3://dummy-bucket/input-param",
+                '{"Get": "Parameters.input_param"}',
                 "--output-path",
-                "s3://dummy-bucket/output-param",
-                "--exec-arg", "ExecutionTest",
-                "--join-arg", "s3://bucket/prefix/file.txt",
-                "--string-param", "test-variable"
+                '{"Get": "Parameters.output_param"}',
+                "--exec-arg",
+                '{"Get": "Execution.ExecutionTest"}',
+                "--join-arg",
+                '{"Std:Join": {"On": "/", "Values": ["s3://bucket", "prefix", "file.txt"]}}',
+                "--string-param",
+                "test-variable",
             ],
-            "ContainerEntrypoint": ["python3", "/opt/ml/processing/input/code/processing_code.py"]
+            "ContainerEntrypoint": ["python3", "/opt/ml/processing/input/code/processing_code.py"],
         },
         "environment": {"my_env_variable": "my_env_variable_value"},
@@ -1937,12 +1930,12 @@ def test_script_processor_with_parameter_string(
         "EnableInterContainerTrafficEncryption": True,
         "VpcConfig": {
             "SecurityGroupIds": ["my_security_group_id"],
-            "Subnets": ["my_subnet_id"]
-        }
+            "Subnets": ["my_subnet_id"],
+        },
         },
         "role_arn": "arn:aws:iam::012345678901:role/SageMakerRole",
         "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
-        "experiment_config": {"ExperimentName": "AnExperiment"}
+        "experiment_config": {"ExperimentName": "AnExperiment"},
     }

     # Run processor
@@ -1955,9 +1948,12 @@ def test_script_processor_with_parameter_string(
             input_param,
             "--output-path",
             output_param,
-            "--exec-arg", exec_var,
-            "--join-arg", join_var,
-            "--string-param", dummy_str_var
+            "--exec-arg",
+            exec_var,
+            "--join-arg",
+            join_var,
+            "--string-param",
+            dummy_str_var,
         ],
         wait=True,
         logs=False,
@@ -1968,5 +1964,3 @@ def test_script_processor_with_parameter_string(
     # Assert
     sagemaker_session.process.assert_called_with(**expected_args)
     assert "test_job" in processor._current_job_name
-
-
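End to end, the series makes it safe to hand pipeline variables to a processor's arguments list: instead of crashing or silently substituting defaults, the placeholders are serialized and forwarded. A rough usage sketch under that reading; the role ARN, image URI, and bucket names are placeholders taken from the tests above, not real resources, and calling run() with real credentials would start an actual processing job:

from sagemaker.processing import ScriptProcessor
from sagemaker.workflow.parameters import ParameterString

input_param = ParameterString(name="input_param", default_value="s3://dummy-bucket/input-param")

processor = ScriptProcessor(
    role="arn:aws:iam::012345678901:role/SageMakerRole",  # placeholder ARN
    image_uri="custom-image-uri",                         # placeholder image
    command=["python3"],
    instance_type="ml.m4.xlarge",
    instance_count=1,
)

# With the fix, the pipeline variable is serialized rather than resolved, so
# the container would receive '{"Get": "Parameters.input_param"}' verbatim.
processor.run(
    code="processing_code.py",
    arguments=["--input-data", input_param],
)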