Skip to content

Commit b88493b

Browse files
committed
Add Owner ID check for bucket with path when prefix is provided
**Description** Previously we called the head_bucket call to ensure the owner ID check, but this doesnt take into consideration cases where the s3 path is provided through the prefix. This change makes sure that director level permissions are supported. **Testing Done** Tested through unit tests, integ tests and manual testing through the installation file. Yes
1 parent 53cb6f8 commit b88493b

File tree

2 files changed

+51
-9
lines changed

2 files changed

+51
-9
lines changed

src/sagemaker/session.py

+22-9
Original file line numberDiff line numberDiff line change
@@ -630,13 +630,12 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
630630
s3 = self.s3_resource
631631

632632
bucket = s3.Bucket(name=bucket_name)
633+
expected_bucket_owner_id = self.account_id()
633634
if bucket.creation_date is None:
634-
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
635+
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True, expected_bucket_owner_id)
635636

636637
elif self._default_bucket_set_by_sdk:
637-
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)
638-
639-
expected_bucket_owner_id = self.account_id()
638+
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False, expected_bucket_owner_id)
640639
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)
641640

642641
def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
@@ -649,9 +648,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket
649648
650649
"""
651650
try:
652-
s3.meta.client.head_bucket(
653-
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
654-
)
651+
if self.default_bucket_prefix:
652+
s3.meta.client.list_objects_v2(
653+
Bucket=bucket_name,
654+
Prefix=self.default_bucket_prefix,
655+
ExpectedBucketOwner=expected_bucket_owner_id
656+
)
657+
else:
658+
s3.meta.client.head_bucket(
659+
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
660+
)
655661
except ClientError as e:
656662
error_code = e.response["Error"]["Code"]
657663
message = e.response["Error"]["Message"]
@@ -668,7 +674,7 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket
668674
raise
669675

670676
def general_bucket_check_if_user_has_permission(
671-
self, bucket_name, s3, bucket, region, bucket_creation_date_none
677+
self, bucket_name, s3, bucket, region, bucket_creation_date_none, expected_bucket_owner_id
672678
):
673679
"""Checks if the person running has the permissions to the bucket
674680
@@ -682,7 +688,14 @@ def general_bucket_check_if_user_has_permission(
682688
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
683689
"""
684690
try:
685-
s3.meta.client.head_bucket(Bucket=bucket_name)
691+
if self.default_bucket_prefix:
692+
s3.meta.client.list_objects_v2(
693+
Bucket=bucket_name,
694+
Prefix=self.default_bucket_prefix,
695+
ExpectedBucketOwner=expected_bucket_owner_id
696+
)
697+
else:
698+
s3.meta.client.head_bucket(Bucket=bucket_name)
686699
except ClientError as e:
687700
error_code = e.response["Error"]["Code"]
688701
message = e.response["Error"]["Message"]

tests/unit/test_default_bucket.py

+29
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ def sagemaker_session():
3939
return sagemaker_session
4040

4141

42+
@pytest.fixture()
43+
def sagemaker_session_with_bucket_name_and_prefix():
44+
boto_mock = MagicMock(name="boto_session", region_name=REGION)
45+
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
46+
sagemaker_session = sagemaker.Session(boto_session=boto_mock,
47+
default_bucket="XXXXXXXXXXXXX",
48+
default_bucket_prefix="sample-prefix")
49+
sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
50+
return sagemaker_session
51+
52+
4253
def test_default_bucket_s3_create_call(sagemaker_session):
4354
error = ClientError(
4455
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
@@ -95,6 +106,24 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime
95106
assert error_message in caplog.text
96107
assert sagemaker_session._default_bucket is None
97108

109+
def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(sagemaker_session_with_bucket_name_and_prefix,
110+
datetime_obj,
111+
caplog):
112+
with pytest.raises(ClientError):
113+
error = ClientError(
114+
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
115+
operation_name="foo",
116+
)
117+
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.side_effect = error
118+
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket(
119+
name=DEFAULT_BUCKET_NAME
120+
).creation_date = None
121+
sagemaker_session_with_bucket_name_and_prefix.default_bucket()
122+
123+
error_message = "Please try again after adding appropriate access."
124+
assert error_message in caplog.text
125+
assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None
126+
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.assert_called_once()
98127

99128
def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog):
100129
sagemaker_session._default_bucket_name_override = "custom-bucket-override"

0 commit comments

Comments
 (0)