diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 44458438c4..461dfd8bab 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -412,29 +412,47 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): bucket = s3.Bucket(name=bucket_name) if bucket.creation_date is None: try: - if region == "us-east-1": - # 'us-east-1' cannot be specified because it is the default region: - # https://github.com/boto/boto3/issues/125 - s3.create_bucket(Bucket=bucket_name) - else: - s3.create_bucket( - Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region} - ) - - LOGGER.info("Created S3 bucket: %s", bucket_name) + # trying head bucket call + s3.meta.client.head_bucket(Bucket=bucket.name) except ClientError as e: + # bucket does not exist or forbidden to access error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] - if error_code == "BucketAlreadyOwnedByYou": - pass - elif ( - error_code == "OperationAborted" - and "conflicting conditional operation" in message - ): - # If this bucket is already being concurrently created, we don't need to create - # it again. - pass + if error_code == "404" and message == "Not Found": + # bucket does not exist, create one + try: + if region == "us-east-1": + # 'us-east-1' cannot be specified because it is the default region: + # https://github.com/boto/boto3/issues/125 + s3.create_bucket(Bucket=bucket_name) + else: + s3.create_bucket( + Bucket=bucket_name, + CreateBucketConfiguration={"LocationConstraint": region}, + ) + + LOGGER.info("Created S3 bucket: %s", bucket_name) + except ClientError as e: + error_code = e.response["Error"]["Code"] + message = e.response["Error"]["Message"] + + if ( + error_code == "OperationAborted" + and "conflicting conditional operation" in message + ): + # If this bucket is already being concurrently created, + # we don't need to create it again. + pass + else: + raise + elif error_code == "403" and message == "Forbidden": + LOGGER.error( + "Bucket %s exists, but access is forbidden. Please try again after " + "adding appropriate access.", + bucket.name, + ) + raise else: raise diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 457a307d12..01dce9ed45 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -14,7 +14,7 @@ import pytest from botocore.exceptions import ClientError -from mock import MagicMock +from mock import MagicMock, patch import sagemaker ACCOUNT_ID = "123" @@ -32,6 +32,11 @@ def sagemaker_session(): def test_default_bucket_s3_create_call(sagemaker_session): + error = ClientError( + error_response={"Error": {"Code": "404", "Message": "Not Found"}}, + operation_name="foo", + ) + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error bucket_name = sagemaker_session.default_bucket() create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls @@ -45,6 +50,25 @@ def test_default_bucket_s3_create_call(sagemaker_session): assert sagemaker_session._default_bucket == bucket_name +def test_default_bucket_s3_needs_access(sagemaker_session): + with patch("logging.Logger.error") as mocked_error_log: + with pytest.raises(ClientError): + error = ClientError( + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, + operation_name="foo", + ) + sagemaker_session.boto_session.resource( + "s3" + ).meta.client.head_bucket.side_effect = error + sagemaker_session.default_bucket() + mocked_error_log.assert_called_once_with( + "Bucket %s exists, but access is forbidden. Please try again after " + "adding appropriate access.", + DEFAULT_BUCKET_NAME, + ) + assert sagemaker_session._default_bucket is None + + def test_default_already_cached(sagemaker_session): existing_default = "mydefaultbucket" sagemaker_session._default_bucket = existing_default @@ -57,11 +81,9 @@ def test_default_already_cached(sagemaker_session): def test_default_bucket_exists(sagemaker_session): - error = ClientError( - error_response={"Error": {"Code": "BucketAlreadyOwnedByYou", "Message": "message"}}, - operation_name="foo", - ) - sagemaker_session.boto_session.resource().create_bucket.side_effect = error + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.return_value = { + "ResponseMetadata": {"RequestId": "xxx", "HTTPStatusCode": 200, "RetryAttempts": 0} + } bucket_name = sagemaker_session.default_bucket() assert bucket_name == DEFAULT_BUCKET_NAME @@ -70,7 +92,7 @@ def test_default_bucket_exists(sagemaker_session): def test_concurrent_bucket_modification(sagemaker_session): message = "A conflicting conditional operation is currently in progress against this resource. Please try again" error = ClientError( - error_response={"Error": {"Code": "BucketAlreadyOwnedByYou", "Message": message}}, + error_response={"Error": {"Code": "OperationAborted", "Message": message}}, operation_name="foo", ) sagemaker_session.boto_session.resource().create_bucket.side_effect = error @@ -80,6 +102,11 @@ def test_concurrent_bucket_modification(sagemaker_session): def test_bucket_creation_client_error(sagemaker_session): + error = ClientError( + error_response={"Error": {"Code": "404", "Message": "Not Found"}}, + operation_name="foo", + ) + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error with pytest.raises(ClientError): error = ClientError( error_response={"Error": {"Code": "SomethingWrong", "Message": "message"}}, @@ -92,6 +119,11 @@ def test_bucket_creation_client_error(sagemaker_session): def test_bucket_creation_other_error(sagemaker_session): + error = ClientError( + error_response={"Error": {"Code": "404", "Message": "Not Found"}}, + operation_name="foo", + ) + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error with pytest.raises(RuntimeError): error = RuntimeError() sagemaker_session.boto_session.resource().create_bucket.side_effect = error