Skip to content

Add Owner ID check for bucket with path when prefix is provided #5146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
31 changes: 22 additions & 9 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,13 +630,12 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
s3 = self.s3_resource

bucket = s3.Bucket(name=bucket_name)
expected_bucket_owner_id = self.account_id()
if bucket.creation_date is None:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True)
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True, expected_bucket_owner_id)

elif self._default_bucket_set_by_sdk:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)

expected_bucket_owner_id = self.account_id()
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False, expected_bucket_owner_id)
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)

def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):
Expand All @@ -649,9 +648,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket

"""
try:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id
)
else:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand All @@ -668,7 +674,7 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket
raise

def general_bucket_check_if_user_has_permission(
self, bucket_name, s3, bucket, region, bucket_creation_date_none
self, bucket_name, s3, bucket, region, bucket_creation_date_none, expected_bucket_owner_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this extra parameter is unused

):
"""Checks if the person running has the permissions to the bucket

Expand All @@ -682,7 +688,14 @@ def general_bucket_check_if_user_has_permission(
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
"""
try:
s3.meta.client.head_bucket(Bucket=bucket_name)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id
Copy link
Contributor

@benieric benieric Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this check is a bit different in this method. Looks like it just checks for general permission to the bucket rather than checking for ExpectedBucketOwner. For this case, probably need to remove the ExpectedBucketOwner=expected_bucket_owner_id to match the behavior of the second block which just does a regular head bucket check - s3.meta.client.head_bucket(Bucket=bucket_name)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good callout. Updated the PR.

)
else:
s3.meta.client.head_bucket(Bucket=bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/test_default_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ def sagemaker_session():
return sagemaker_session


@pytest.fixture()
def sagemaker_session_with_bucket_name_and_prefix():
boto_mock = MagicMock(name="boto_session", region_name=REGION)
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
sagemaker_session = sagemaker.Session(boto_session=boto_mock,
default_bucket="XXXXXXXXXXXXX",
default_bucket_prefix="sample-prefix")
sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
return sagemaker_session


def test_default_bucket_s3_create_call(sagemaker_session):
error = ClientError(
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
Expand Down Expand Up @@ -95,6 +106,24 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime
assert error_message in caplog.text
assert sagemaker_session._default_bucket is None

def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(sagemaker_session_with_bucket_name_and_prefix,
datetime_obj,
caplog):
with pytest.raises(ClientError):
error = ClientError(
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
operation_name="foo",
)
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.side_effect = error
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket(
name=DEFAULT_BUCKET_NAME
).creation_date = None
sagemaker_session_with_bucket_name_and_prefix.default_bucket()

error_message = "Please try again after adding appropriate access."
assert error_message in caplog.text
assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.assert_called_once()

def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog):
sagemaker_session._default_bucket_name_override = "custom-bucket-override"
Expand Down
Loading