Skip to content

Commit ec7cf62

Browse files
mufaddal-rohawalajerrypeng7773
authored andcommitted
fix: bucket exists check for session.default_bucket (aws#3165)
1 parent 4d27ff4 commit ec7cf62

File tree

2 files changed

+76
-26
lines changed

2 files changed

+76
-26
lines changed

src/sagemaker/session.py

+37-19
Original file line numberDiff line numberDiff line change
@@ -412,29 +412,47 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
412412
bucket = s3.Bucket(name=bucket_name)
413413
if bucket.creation_date is None:
414414
try:
415-
if region == "us-east-1":
416-
# 'us-east-1' cannot be specified because it is the default region:
417-
# https://github.com/boto/boto3/issues/125
418-
s3.create_bucket(Bucket=bucket_name)
419-
else:
420-
s3.create_bucket(
421-
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
422-
)
423-
424-
LOGGER.info("Created S3 bucket: %s", bucket_name)
415+
# trying head bucket call
416+
s3.meta.client.head_bucket(Bucket=bucket.name)
425417
except ClientError as e:
418+
# bucket does not exist or forbidden to access
426419
error_code = e.response["Error"]["Code"]
427420
message = e.response["Error"]["Message"]
428421

429-
if error_code == "BucketAlreadyOwnedByYou":
430-
pass
431-
elif (
432-
error_code == "OperationAborted"
433-
and "conflicting conditional operation" in message
434-
):
435-
# If this bucket is already being concurrently created, we don't need to create
436-
# it again.
437-
pass
422+
if error_code == "404" and message == "Not Found":
423+
# bucket does not exist, create one
424+
try:
425+
if region == "us-east-1":
426+
# 'us-east-1' cannot be specified because it is the default region:
427+
# https://github.com/boto/boto3/issues/125
428+
s3.create_bucket(Bucket=bucket_name)
429+
else:
430+
s3.create_bucket(
431+
Bucket=bucket_name,
432+
CreateBucketConfiguration={"LocationConstraint": region},
433+
)
434+
435+
LOGGER.info("Created S3 bucket: %s", bucket_name)
436+
except ClientError as e:
437+
error_code = e.response["Error"]["Code"]
438+
message = e.response["Error"]["Message"]
439+
440+
if (
441+
error_code == "OperationAborted"
442+
and "conflicting conditional operation" in message
443+
):
444+
# If this bucket is already being concurrently created,
445+
# we don't need to create it again.
446+
pass
447+
else:
448+
raise
449+
elif error_code == "403" and message == "Forbidden":
450+
LOGGER.error(
451+
"Bucket %s exists, but access is forbidden. Please try again after "
452+
"adding appropriate access.",
453+
bucket.name,
454+
)
455+
raise
438456
else:
439457
raise
440458

tests/unit/test_default_bucket.py

+39-7
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import pytest
1616
from botocore.exceptions import ClientError
17-
from mock import MagicMock
17+
from mock import MagicMock, patch
1818
import sagemaker
1919

2020
ACCOUNT_ID = "123"
@@ -32,6 +32,11 @@ def sagemaker_session():
3232

3333

3434
def test_default_bucket_s3_create_call(sagemaker_session):
35+
error = ClientError(
36+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
37+
operation_name="foo",
38+
)
39+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error
3540
bucket_name = sagemaker_session.default_bucket()
3641

3742
create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
@@ -45,6 +50,25 @@ def test_default_bucket_s3_create_call(sagemaker_session):
4550
assert sagemaker_session._default_bucket == bucket_name
4651

4752

53+
def test_default_bucket_s3_needs_access(sagemaker_session):
54+
with patch("logging.Logger.error") as mocked_error_log:
55+
with pytest.raises(ClientError):
56+
error = ClientError(
57+
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
58+
operation_name="foo",
59+
)
60+
sagemaker_session.boto_session.resource(
61+
"s3"
62+
).meta.client.head_bucket.side_effect = error
63+
sagemaker_session.default_bucket()
64+
mocked_error_log.assert_called_once_with(
65+
"Bucket %s exists, but access is forbidden. Please try again after "
66+
"adding appropriate access.",
67+
DEFAULT_BUCKET_NAME,
68+
)
69+
assert sagemaker_session._default_bucket is None
70+
71+
4872
def test_default_already_cached(sagemaker_session):
4973
existing_default = "mydefaultbucket"
5074
sagemaker_session._default_bucket = existing_default
@@ -57,11 +81,9 @@ def test_default_already_cached(sagemaker_session):
5781

5882

5983
def test_default_bucket_exists(sagemaker_session):
60-
error = ClientError(
61-
error_response={"Error": {"Code": "BucketAlreadyOwnedByYou", "Message": "message"}},
62-
operation_name="foo",
63-
)
64-
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
84+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.return_value = {
85+
"ResponseMetadata": {"RequestId": "xxx", "HTTPStatusCode": 200, "RetryAttempts": 0}
86+
}
6587

6688
bucket_name = sagemaker_session.default_bucket()
6789
assert bucket_name == DEFAULT_BUCKET_NAME
@@ -70,7 +92,7 @@ def test_default_bucket_exists(sagemaker_session):
7092
def test_concurrent_bucket_modification(sagemaker_session):
7193
message = "A conflicting conditional operation is currently in progress against this resource. Please try again"
7294
error = ClientError(
73-
error_response={"Error": {"Code": "BucketAlreadyOwnedByYou", "Message": message}},
95+
error_response={"Error": {"Code": "OperationAborted", "Message": message}},
7496
operation_name="foo",
7597
)
7698
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
@@ -80,6 +102,11 @@ def test_concurrent_bucket_modification(sagemaker_session):
80102

81103

82104
def test_bucket_creation_client_error(sagemaker_session):
105+
error = ClientError(
106+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
107+
operation_name="foo",
108+
)
109+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error
83110
with pytest.raises(ClientError):
84111
error = ClientError(
85112
error_response={"Error": {"Code": "SomethingWrong", "Message": "message"}},
@@ -92,6 +119,11 @@ def test_bucket_creation_client_error(sagemaker_session):
92119

93120

94121
def test_bucket_creation_other_error(sagemaker_session):
122+
error = ClientError(
123+
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
124+
operation_name="foo",
125+
)
126+
sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error
95127
with pytest.raises(RuntimeError):
96128
error = RuntimeError()
97129
sagemaker_session.boto_session.resource().create_bucket.side_effect = error

0 commit comments

Comments
 (0)