Skip to content

Add Owner ID check for bucket with path when prefix is provided #5146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
21 changes: 16 additions & 5 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,6 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):

elif self._default_bucket_set_by_sdk:
self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False)

expected_bucket_owner_id = self.account_id()
self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id)

Expand All @@ -649,9 +648,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket

"""
try:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name,
Prefix=self.default_bucket_prefix,
ExpectedBucketOwner=expected_bucket_owner_id,
)
else:
s3.meta.client.head_bucket(
Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id
)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand Down Expand Up @@ -682,7 +688,12 @@ def general_bucket_check_if_user_has_permission(
bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not
"""
try:
s3.meta.client.head_bucket(Bucket=bucket_name)
if self.default_bucket_prefix:
s3.meta.client.list_objects_v2(
Bucket=bucket_name, Prefix=self.default_bucket_prefix
)
else:
s3.meta.client.head_bucket(Bucket=bucket_name)
except ClientError as e:
error_code = e.response["Error"]["Code"]
message = e.response["Error"]["Message"]
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/test_default_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ def sagemaker_session():
return sagemaker_session


@pytest.fixture()
def sagemaker_session_with_bucket_name_and_prefix():
boto_mock = MagicMock(name="boto_session", region_name=REGION)
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
sagemaker_session = sagemaker.Session(
boto_session=boto_mock,
default_bucket="XXXXXXXXXXXXX",
default_bucket_prefix="sample-prefix",
)
sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None
return sagemaker_session


def test_default_bucket_s3_create_call(sagemaker_session):
error = ClientError(
error_response={"Error": {"Code": "404", "Message": "Not Found"}},
Expand Down Expand Up @@ -96,6 +109,30 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime
assert sagemaker_session._default_bucket is None


def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(
sagemaker_session_with_bucket_name_and_prefix, datetime_obj, caplog
):
with pytest.raises(ClientError):
error = ClientError(
error_response={"Error": {"Code": "403", "Message": "Forbidden"}},
operation_name="foo",
)
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource(
"s3"
).meta.client.list_objects_v2.side_effect = error
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket(
name=DEFAULT_BUCKET_NAME
).creation_date = None
sagemaker_session_with_bucket_name_and_prefix.default_bucket()

error_message = "Please try again after adding appropriate access."
assert error_message in caplog.text
assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None
sagemaker_session_with_bucket_name_and_prefix.boto_session.resource(
"s3"
).meta.client.list_objects_v2.assert_called_once()


def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog):
sagemaker_session._default_bucket_name_override = "custom-bucket-override"
error = ClientError(
Expand Down