diff --git a/src/sagemaker/lambda_helper.py b/src/sagemaker/lambda_helper.py index 9caf0db16b..2e74569f84 100644 --- a/src/sagemaker/lambda_helper.py +++ b/src/sagemaker/lambda_helper.py @@ -15,6 +15,7 @@ from io import BytesIO import zipfile +import time from botocore.exceptions import ClientError from sagemaker.session import Session @@ -134,32 +135,35 @@ def update(self): Returns: boto3 response from Lambda's update_function method. """ lambda_client = _get_lambda_client(self.session) - - if self.script is not None: - try: - response = lambda_client.update_function_code( - FunctionName=self.function_name, ZipFile=_zip_lambda_code(self.script) - ) - return response - except ClientError as e: - error = e.response["Error"] - raise ValueError(error) - else: + retry_attempts = 7 + for i in range(retry_attempts): try: - response = lambda_client.update_function_code( - FunctionName=(self.function_name or self.function_arn), - S3Bucket=self.s3_bucket, - S3Key=_upload_to_s3( - s3_client=_get_s3_client(self.session), - function_name=self.function_name, - zipped_code_dir=self.zipped_code_dir, - s3_bucket=self.s3_bucket, - ), - ) + if self.script is not None: + response = lambda_client.update_function_code( + FunctionName=self.function_name, ZipFile=_zip_lambda_code(self.script) + ) + else: + response = lambda_client.update_function_code( + FunctionName=(self.function_name or self.function_arn), + S3Bucket=self.s3_bucket, + S3Key=_upload_to_s3( + s3_client=_get_s3_client(self.session), + function_name=self.function_name, + zipped_code_dir=self.zipped_code_dir, + s3_bucket=self.s3_bucket, + ), + ) return response except ClientError as e: error = e.response["Error"] - raise ValueError(error) + code = error["Code"] + if code == "ResourceConflictException": + if i == retry_attempts - 1: + raise ValueError(error) + # max wait time = 2**0 + 2**1 + .. + 2**6 = 127 seconds + time.sleep(2**i) + else: + raise ValueError(error) def upsert(self): """Method to create a lambda function or update it if it already exists diff --git a/src/sagemaker/workflow/pipeline.py b/src/sagemaker/workflow/pipeline.py index e1ad50e6cf..02d584edda 100644 --- a/src/sagemaker/workflow/pipeline.py +++ b/src/sagemaker/workflow/pipeline.py @@ -215,30 +215,33 @@ def upsert( Returns: response dict from service """ + exists = True try: - response = self.create(role_arn, description, tags, parallelism_config) + self.describe() except ClientError as e: - error = e.response["Error"] - if ( - error["Code"] == "ValidationException" - and "Pipeline names must be unique within" in error["Message"] - ): - response = self.update(role_arn, description) - if tags is not None: - old_tags = self.sagemaker_session.sagemaker_client.list_tags( - ResourceArn=response["PipelineArn"] - )["Tags"] - - tag_keys = [tag["Key"] for tag in tags] - for old_tag in old_tags: - if old_tag["Key"] not in tag_keys: - tags.append(old_tag) - - self.sagemaker_session.sagemaker_client.add_tags( - ResourceArn=response["PipelineArn"], Tags=tags - ) + err = e.response.get("Error", {}) + if err.get("Code", None) == "ResourceNotFound": + exists = False else: - raise + raise e + + if not exists: + response = self.create(role_arn, description, tags, parallelism_config) + else: + response = self.update(role_arn, description) + if tags is not None: + old_tags = self.sagemaker_session.sagemaker_client.list_tags( + ResourceArn=response["PipelineArn"] + )["Tags"] + + tag_keys = [tag["Key"] for tag in tags] + for old_tag in old_tags: + if old_tag["Key"] not in tag_keys: + tags.append(old_tag) + + self.sagemaker_session.sagemaker_client.add_tags( + ResourceArn=response["PipelineArn"], Tags=tags + ) return response def delete(self) -> Dict[str, Any]: @@ -270,18 +273,6 @@ def start( Returns: A `_PipelineExecution` instance, if successful. """ - exists = True - try: - self.describe() - except ClientError: - exists = False - - if not exists: - raise ValueError( - "This pipeline is not associated with a Pipeline in SageMaker. " - "Please invoke create() first before attempting to invoke start()." - ) - kwargs = dict(PipelineName=self.name) update_args( kwargs, diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py index be90a8a876..f39a012df8 100644 --- a/tests/unit/sagemaker/workflow/test_pipeline.py +++ b/tests/unit/sagemaker/workflow/test_pipeline.py @@ -17,8 +17,6 @@ import pytest -from botocore.exceptions import ClientError - from mock import Mock from sagemaker import s3 @@ -178,20 +176,15 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn): def test_pipeline_upsert(sagemaker_session_mock, role_arn): - sagemaker_session_mock.side_effect = [ - ClientError( - operation_name="CreatePipeline", - error_response={ - "Error": { - "Code": "ValidationException", - "Message": "Pipeline names must be unique within ...", - } - }, - ), - {"PipelineArn": "mock_pipeline_arn"}, - [{"Key": "dummy", "Value": "dummy_tag"}], - {}, - ] + sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = { + "PipelineArn": "pipeline-arn" + } + sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = { + "PipelineArn": "pipeline-arn" + } + sagemaker_session_mock.sagemaker_client.list_tags.return_value = { + "Tags": [{"Key": "dummy", "Value": "dummy_tag"}] + } pipeline = Pipeline( name="MyPipeline", @@ -205,9 +198,9 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn): {"Key": "bar", "Value": "xyz"}, ] pipeline.upsert(role_arn=role_arn, tags=tags) - assert sagemaker_session_mock.sagemaker_client.create_pipeline.called_with( - PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn - ) + + sagemaker_session_mock.sagemaker_client.create_pipeline.assert_not_called() + assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with( PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn ) @@ -273,18 +266,6 @@ def test_pipeline_start(sagemaker_session_mock): ) -def test_pipeline_start_before_creation(sagemaker_session_mock): - sagemaker_session_mock.sagemaker_client.describe_pipeline.side_effect = ClientError({}, "bar") - pipeline = Pipeline( - name="MyPipeline", - parameters=[ParameterString("alpha", "beta"), ParameterString("gamma", "delta")], - steps=[], - sagemaker_session=sagemaker_session_mock, - ) - with pytest.raises(ValueError): - pipeline.start() - - def test_pipeline_basic(): parameter = ParameterString("MyStr") pipeline = Pipeline(