Skip to content

Commit ae3247a

Browse files
authored
fix: modify session and kms_utils to check for S3 bucket before creation (aws#1271)
Previously, the code always attempted to create the bucket and then decided how to proceed based on the exception that was thrown. This commit makes the code check whether the bucket already exists before trying to create it. In addition to being cleaner, this avoids failures for customers who have no S3 bucket creation permissions: when the bucket already exists, no create call is made.
1 parent ea6fc7e commit ae3247a

File tree

4 files changed

+60
-53
lines changed

4 files changed

+60
-53
lines changed

src/sagemaker/session.py

Lines changed: 47 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -342,40 +342,58 @@ def default_bucket(self):
342342
).get_caller_identity()["Account"]
343343
default_bucket = "sagemaker-{}-{}".format(region, account)
344344

345-
s3 = self.boto_session.resource("s3")
346-
try:
347-
# 'us-east-1' cannot be specified because it is the default region:
348-
# https://github.com/boto/boto3/issues/125
349-
if region == "us-east-1":
350-
s3.create_bucket(Bucket=default_bucket)
351-
else:
352-
s3.create_bucket(
353-
Bucket=default_bucket, CreateBucketConfiguration={"LocationConstraint": region}
354-
)
355-
356-
LOGGER.info("Created S3 bucket: %s", default_bucket)
357-
except ClientError as e:
358-
error_code = e.response["Error"]["Code"]
359-
message = e.response["Error"]["Message"]
360-
361-
if error_code == "BucketAlreadyOwnedByYou":
362-
pass
363-
elif (
364-
error_code == "OperationAborted" and "conflicting conditional operation" in message
365-
):
366-
# If this bucket is already being concurrently created, we don't need to create it
367-
# again.
368-
pass
369-
elif error_code == "TooManyBuckets":
370-
# Succeed if the default bucket exists
371-
s3.meta.client.head_bucket(Bucket=default_bucket)
372-
else:
373-
raise
345+
self._create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=region)
374346

375347
self._default_bucket = default_bucket
376348

377349
return self._default_bucket
378350

351+
def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
352+
"""Creates an S3 Bucket if it does not exist.
353+
Also swallows a few common exceptions that indicate that the bucket already exists or
354+
that it is being created.
355+
356+
Args:
357+
bucket_name (str): Name of the S3 bucket to be created.
358+
region (str): The region in which to create the bucket.
359+
360+
Raises:
361+
botocore.exceptions.ClientError: If S3 throws an unexpected exception during bucket
362+
creation.
363+
If the exception is due to the bucket already existing or
364+
already being created, no exception is raised.
365+
366+
"""
367+
bucket = self.boto_session.resource("s3", region_name=region).Bucket(name=bucket_name)
368+
if bucket.creation_date is None:
369+
try:
370+
s3 = self.boto_session.resource("s3", region_name=region)
371+
if region == "us-east-1":
372+
# 'us-east-1' cannot be specified because it is the default region:
373+
# https://github.com/boto/boto3/issues/125
374+
s3.create_bucket(Bucket=bucket_name)
375+
else:
376+
s3.create_bucket(
377+
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
378+
)
379+
380+
LOGGER.info("Created S3 bucket: %s", bucket_name)
381+
except ClientError as e:
382+
error_code = e.response["Error"]["Code"]
383+
message = e.response["Error"]["Message"]
384+
385+
if error_code == "BucketAlreadyOwnedByYou":
386+
pass
387+
elif (
388+
error_code == "OperationAborted"
389+
and "conflicting conditional operation" in message
390+
):
391+
# If this bucket is already being concurrently created, we don't need to create
392+
# it again.
393+
pass
394+
else:
395+
raise
396+
379397
def train( # noqa: C901
380398
self,
381399
input_mode,

tests/integ/kms_utils.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
import contextlib
1616
import json
1717

18-
from botocore import exceptions
19-
2018
from sagemaker import utils
2119

2220
PRINCIPAL_TEMPLATE = (
@@ -158,7 +156,8 @@ def get_or_create_kms_key(
158156

159157

160158
@contextlib.contextmanager
161-
def bucket_with_encryption(boto_session, sagemaker_role):
159+
def bucket_with_encryption(sagemaker_session, sagemaker_role):
160+
boto_session = sagemaker_session.boto_session
162161
region = boto_session.region_name
163162
sts_client = boto_session.client(
164163
"sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region)
@@ -173,22 +172,10 @@ def bucket_with_encryption(boto_session, sagemaker_role):
173172
region = boto_session.region_name
174173
bucket_name = "sagemaker-{}-{}-with-kms".format(region, account)
175174

176-
s3 = boto_session.client("s3")
177-
try:
178-
# 'us-east-1' cannot be specified because it is the default region:
179-
# https://github.com/boto/boto3/issues/125
180-
if region == "us-east-1":
181-
s3.create_bucket(Bucket=bucket_name)
182-
else:
183-
s3.create_bucket(
184-
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
185-
)
186-
187-
except exceptions.ClientError as e:
188-
if e.response["Error"]["Code"] != "BucketAlreadyOwnedByYou":
189-
raise
190-
191-
s3.put_bucket_encryption(
175+
sagemaker_session._create_s3_bucket_if_it_does_not_exist(bucket_name=bucket_name, region=region)
176+
177+
s3_client = boto_session.client("s3", region_name=region)
178+
s3_client.put_bucket_encryption(
192179
Bucket=bucket_name,
193180
ServerSideEncryptionConfiguration={
194181
"Rules": [
@@ -202,7 +189,9 @@ def bucket_with_encryption(boto_session, sagemaker_role):
202189
},
203190
)
204191

205-
s3.put_bucket_policy(Bucket=bucket_name, Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name))
192+
s3_client.put_bucket_policy(
193+
Bucket=bucket_name, Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)
194+
)
206195

207196
yield "s3://" + bucket_name, kms_key_arn
208197

tests/integ/test_tf_script_mode.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type, tf_full_
8383

8484

8585
def test_server_side_encryption(sagemaker_session, tf_full_version):
86-
boto_session = sagemaker_session.boto_session
87-
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
86+
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
8887
output_path = os.path.join(
8988
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
9089
)

tests/unit/test_default_bucket.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@
2626
def sagemaker_session():
2727
boto_mock = Mock(name="boto_session", region_name=REGION)
2828
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
29-
ims = sagemaker.Session(boto_session=boto_mock)
30-
return ims
29+
sagemaker_session = sagemaker.Session(boto_session=boto_mock)
30+
sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
31+
return sagemaker_session
3132

3233

3334
def test_default_bucket_s3_create_call(sagemaker_session):

0 commit comments

Comments
 (0)