Skip to content

Commit 7454f04

Browse files
committed
refactor s3 bucket creation
1 parent c6f0aa7 commit 7454f04

File tree

3 files changed

+30
-43
lines changed

3 files changed

+30
-43
lines changed

src/sagemaker/session.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -342,21 +342,36 @@ def default_bucket(self):
342342
).get_caller_identity()["Account"]
343343
default_bucket = "sagemaker-{}-{}".format(region, account)
344344

345+
self.create_s3_bucket_if_it_does_not_exist(bucket_name=default_bucket, region=region)
346+
347+
self._default_bucket = default_bucket
348+
349+
return self._default_bucket
350+
351+
def create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
352+
"""Creates an S3 Bucket if it does not exist.
353+
Also swallows a few common exceptions that indicate that the bucket already exists, or
354+
that it is being created.
355+
356+
Args:
357+
bucket_name (str): Name of the S3 bucket to be created.
358+
region (str): The region in which to create the bucket.
359+
360+
"""
345361
s3 = self.boto_session.resource("s3", region_name=region)
346-
bucket = self.boto_session.resource("s3", region_name=region).Bucket(name=default_bucket)
362+
bucket = self.boto_session.resource("s3", region_name=region).Bucket(name=bucket_name)
347363
if bucket.creation_date is None:
348364
try:
349365
if region == "us-east-1":
350366
# 'us-east-1' cannot be specified because it is the default region:
351367
# https://github.com/boto/boto3/issues/125
352-
s3.create_bucket(Bucket=default_bucket)
368+
s3.create_bucket(Bucket=bucket_name)
353369
else:
354370
s3.create_bucket(
355-
Bucket=default_bucket,
356-
CreateBucketConfiguration={"LocationConstraint": region},
371+
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
357372
)
358373

359-
LOGGER.info("Created S3 bucket: %s", default_bucket)
374+
LOGGER.info("Created S3 bucket: %s", bucket_name)
360375
except ClientError as e:
361376
error_code = e.response["Error"]["Code"]
362377
message = e.response["Error"]["Message"]
@@ -373,10 +388,6 @@ def default_bucket(self):
373388
else:
374389
raise
375390

376-
self._default_bucket = default_bucket
377-
378-
return self._default_bucket
379-
380391
def train( # noqa: C901
381392
self,
382393
input_mode,

tests/integ/kms_utils.py

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
import contextlib
1616
import json
1717

18-
from botocore import exceptions
19-
2018
from sagemaker import utils
2119

2220
PRINCIPAL_TEMPLATE = (
@@ -158,7 +156,8 @@ def get_or_create_kms_key(
158156

159157

160158
@contextlib.contextmanager
161-
def bucket_with_encryption(boto_session, sagemaker_role):
159+
def bucket_with_encryption(sagemaker_session, sagemaker_role):
160+
boto_session = sagemaker_session.boto_session
162161
region = boto_session.region_name
163162
sts_client = boto_session.client(
164163
"sts", region_name=region, endpoint_url=utils.sts_regional_endpoint(region)
@@ -173,34 +172,10 @@ def bucket_with_encryption(boto_session, sagemaker_role):
173172
region = boto_session.region_name
174173
bucket_name = "sagemaker-{}-{}-with-kms".format(region, account)
175174

176-
s3 = boto_session.client("s3", region_name=region)
177-
bucket = boto_session.resource("s3", region_name=region).Bucket(name=bucket_name)
178-
if bucket.creation_date is None:
179-
try:
180-
if region == "us-east-1":
181-
# 'us-east-1' cannot be specified because it is the default region:
182-
# https://github.com/boto/boto3/issues/125
183-
s3.create_bucket(Bucket=bucket_name)
184-
else:
185-
s3.create_bucket(
186-
Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}
187-
)
188-
except exceptions.ClientError as e:
189-
error_code = e.response["Error"]["Code"]
190-
message = e.response["Error"]["Message"]
191-
192-
if error_code == "BucketAlreadyOwnedByYou":
193-
pass
194-
elif (
195-
error_code == "OperationAborted" and "conflicting conditional operation" in message
196-
):
197-
# If this bucket is already being concurrently created, we don't need to create it
198-
# again.
199-
pass
200-
else:
201-
raise
202-
203-
s3.put_bucket_encryption(
175+
sagemaker_session.create_bucket_if_it_does_not_exist(bucket_name=bucket_name, region=region)
176+
177+
s3_client = boto_session.client("s3", region_name=region)
178+
s3_client.put_bucket_encryption(
204179
Bucket=bucket_name,
205180
ServerSideEncryptionConfiguration={
206181
"Rules": [
@@ -214,7 +189,9 @@ def bucket_with_encryption(boto_session, sagemaker_role):
214189
},
215190
)
216191

217-
s3.put_bucket_policy(Bucket=bucket_name, Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name))
192+
s3_client.put_bucket_policy(
193+
Bucket=bucket_name, Policy=KMS_BUCKET_POLICY % (bucket_name, bucket_name)
194+
)
218195

219196
yield "s3://" + bucket_name, kms_key_arn
220197

tests/integ/test_tf_script_mode.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ def test_mnist_with_checkpoint_config(sagemaker_session, instance_type, tf_full_
8383

8484

8585
def test_server_side_encryption(sagemaker_session, tf_full_version):
86-
boto_session = sagemaker_session.boto_session
87-
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
86+
with kms_utils.bucket_with_encryption(sagemaker_session, ROLE) as (bucket_with_kms, kms_key):
8887
output_path = os.path.join(
8988
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
9089
)

0 commit comments

Comments
 (0)