From 637bc2e2b6ec1bf3ff93b69e687e8ab71c6340ab Mon Sep 17 00:00:00 2001 From: Gokul A <166456257+nargokul@users.noreply.github.com> Date: Fri, 20 Dec 2024 09:31:54 -0800 Subject: [PATCH 1/7] Fix Flake8 Violations --- .../model_server/multi_model_server/inference.py | 14 ++++++++++++-- .../serve/model_server/torchserve/inference.py | 14 ++++++++++++-- .../model_server/torchserve/xgboost_inference.py | 14 ++++++++++++-- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/serve/model_server/multi_model_server/inference.py b/src/sagemaker/serve/model_server/multi_model_server/inference.py index 3cece40c5e..1d2440f5f9 100644 --- a/src/sagemaker/serve/model_server/multi_model_server/inference.py +++ b/src/sagemaker/serve/model_server/multi_model_server/inference.py @@ -45,11 +45,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) # Check if preprocess method is defined and call it diff --git a/src/sagemaker/serve/model_server/torchserve/inference.py b/src/sagemaker/serve/model_server/torchserve/inference.py index 294c032ccc..489cc1bc1e 100644 --- a/src/sagemaker/serve/model_server/torchserve/inference.py +++ b/src/sagemaker/serve/model_server/torchserve/inference.py @@ -67,11 +67,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): deserialized_data = schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: deserialized_data = schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) # Check if preprocess method is defined and call it diff --git a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py index 6dab9bc6c6..517c774bbc 100644 --- a/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py +++ b/src/sagemaker/serve/model_server/torchserve/xgboost_inference.py @@ -70,11 +70,21 @@ def input_fn(input_data, content_type): try: if hasattr(schema_builder, "custom_input_translator"): return schema_builder.custom_input_translator.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type, ) else: return schema_builder.input_deserializer.deserialize( - io.BytesIO(input_data) if type(input_data)== bytes else io.BytesIO(input_data.encode('utf-8')), content_type[0] + ( + io.BytesIO(input_data) + if type(input_data) == bytes + else io.BytesIO(input_data.encode("utf-8")) + ), + content_type[0], ) except Exception as e: raise Exception("Encountered error in deserialize_request.") from e From b88493b01d152d4bfd61b26a40a515f960308589 Mon Sep 17 00:00:00 2001 From: Gokul A Date: Fri, 25 Apr 2025 11:05:46 -0700 Subject: [PATCH 2/7] Add Owner ID check for bucket with path when prefix is provided **Description** Previously we called the head_bucket call to ensure the owner ID check, but this doesnt take into consideration cases where the s3 path is provided through the prefix. This change makes sure that director level permissions are supported. **Testing Done** Tested through unit tests, integ tests and manual testing through the installation file. Yes --- src/sagemaker/session.py | 31 ++++++++++++++++++++++--------- tests/unit/test_default_bucket.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 9 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 797d559348..8e205bf278 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -630,13 +630,12 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): s3 = self.s3_resource bucket = s3.Bucket(name=bucket_name) + expected_bucket_owner_id = self.account_id() if bucket.creation_date is None: - self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True) + self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True, expected_bucket_owner_id) elif self._default_bucket_set_by_sdk: - self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False) - - expected_bucket_owner_id = self.account_id() + self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False, expected_bucket_owner_id) self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id): @@ -649,9 +648,16 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket """ try: - s3.meta.client.head_bucket( - Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id - ) + if self.default_bucket_prefix: + s3.meta.client.list_objects_v2( + Bucket=bucket_name, + Prefix=self.default_bucket_prefix, + ExpectedBucketOwner=expected_bucket_owner_id + ) + else: + s3.meta.client.head_bucket( + Bucket=bucket_name, ExpectedBucketOwner=expected_bucket_owner_id + ) except ClientError as e: error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] @@ -668,7 +674,7 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket raise def general_bucket_check_if_user_has_permission( - self, bucket_name, s3, bucket, region, bucket_creation_date_none + self, bucket_name, s3, bucket, region, bucket_creation_date_none, expected_bucket_owner_id ): """Checks if the person running has the permissions to the bucket @@ -682,7 +688,14 @@ def general_bucket_check_if_user_has_permission( bucket_creation_date_none (bool):Indicating whether S3 bucket already exists or not """ try: - s3.meta.client.head_bucket(Bucket=bucket_name) + if self.default_bucket_prefix: + s3.meta.client.list_objects_v2( + Bucket=bucket_name, + Prefix=self.default_bucket_prefix, + ExpectedBucketOwner=expected_bucket_owner_id + ) + else: + s3.meta.client.head_bucket(Bucket=bucket_name) except ClientError as e: error_code = e.response["Error"]["Code"] message = e.response["Error"]["Message"] diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 6ce4b50c75..59d758138a 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -39,6 +39,17 @@ def sagemaker_session(): return sagemaker_session +@pytest.fixture() +def sagemaker_session_with_bucket_name_and_prefix(): + boto_mock = MagicMock(name="boto_session", region_name=REGION) + boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} + sagemaker_session = sagemaker.Session(boto_session=boto_mock, + default_bucket="XXXXXXXXXXXXX", + default_bucket_prefix="sample-prefix") + sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None + return sagemaker_session + + def test_default_bucket_s3_create_call(sagemaker_session): error = ClientError( error_response={"Error": {"Code": "404", "Message": "Not Found"}}, @@ -95,6 +106,24 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime assert error_message in caplog.text assert sagemaker_session._default_bucket is None +def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(sagemaker_session_with_bucket_name_and_prefix, + datetime_obj, + caplog): + with pytest.raises(ClientError): + error = ClientError( + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, + operation_name="foo", + ) + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.side_effect = error + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket( + name=DEFAULT_BUCKET_NAME + ).creation_date = None + sagemaker_session_with_bucket_name_and_prefix.default_bucket() + + error_message = "Please try again after adding appropriate access." + assert error_message in caplog.text + assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.assert_called_once() def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog): sagemaker_session._default_bucket_name_override = "custom-bucket-override" From 61b69c623fbf453ce35eab8547b36c8789f757b3 Mon Sep 17 00:00:00 2001 From: Gokul A Date: Fri, 25 Apr 2025 11:16:24 -0700 Subject: [PATCH 3/7] Address PR comment --- src/sagemaker/session.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 8e205bf278..1ef80592d8 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -691,8 +691,7 @@ def general_bucket_check_if_user_has_permission( if self.default_bucket_prefix: s3.meta.client.list_objects_v2( Bucket=bucket_name, - Prefix=self.default_bucket_prefix, - ExpectedBucketOwner=expected_bucket_owner_id + Prefix=self.default_bucket_prefix ) else: s3.meta.client.head_bucket(Bucket=bucket_name) From 6e735fac95ce4dc12c13a57ae9db32c044340e72 Mon Sep 17 00:00:00 2001 From: Gokul A Date: Fri, 25 Apr 2025 11:27:43 -0700 Subject: [PATCH 4/7] Codestyle fixes --- src/sagemaker/session.py | 13 ++++++++----- tests/unit/test_default_bucket.py | 24 ++++++++++++++++-------- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 1ef80592d8..58f40b8a9c 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -632,10 +632,14 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): bucket = s3.Bucket(name=bucket_name) expected_bucket_owner_id = self.account_id() if bucket.creation_date is None: - self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True, expected_bucket_owner_id) + self.general_bucket_check_if_user_has_permission( + bucket_name, s3, bucket, region, True, expected_bucket_owner_id + ) elif self._default_bucket_set_by_sdk: - self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False, expected_bucket_owner_id) + self.general_bucket_check_if_user_has_permission( + bucket_name, s3, bucket, region, False, expected_bucket_owner_id + ) self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id): @@ -652,7 +656,7 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket s3.meta.client.list_objects_v2( Bucket=bucket_name, Prefix=self.default_bucket_prefix, - ExpectedBucketOwner=expected_bucket_owner_id + ExpectedBucketOwner=expected_bucket_owner_id, ) else: s3.meta.client.head_bucket( @@ -690,8 +694,7 @@ def general_bucket_check_if_user_has_permission( try: if self.default_bucket_prefix: s3.meta.client.list_objects_v2( - Bucket=bucket_name, - Prefix=self.default_bucket_prefix + Bucket=bucket_name, Prefix=self.default_bucket_prefix ) else: s3.meta.client.head_bucket(Bucket=bucket_name) diff --git a/tests/unit/test_default_bucket.py b/tests/unit/test_default_bucket.py index 59d758138a..dca1d3dc85 100644 --- a/tests/unit/test_default_bucket.py +++ b/tests/unit/test_default_bucket.py @@ -43,9 +43,11 @@ def sagemaker_session(): def sagemaker_session_with_bucket_name_and_prefix(): boto_mock = MagicMock(name="boto_session", region_name=REGION) boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID} - sagemaker_session = sagemaker.Session(boto_session=boto_mock, - default_bucket="XXXXXXXXXXXXX", - default_bucket_prefix="sample-prefix") + sagemaker_session = sagemaker.Session( + boto_session=boto_mock, + default_bucket="XXXXXXXXXXXXX", + default_bucket_prefix="sample-prefix", + ) sagemaker_session.boto_session.resource("s3").Bucket().creation_date = None return sagemaker_session @@ -106,15 +108,18 @@ def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime assert error_message in caplog.text assert sagemaker_session._default_bucket is None -def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(sagemaker_session_with_bucket_name_and_prefix, - datetime_obj, - caplog): + +def test_default_bucket_with_prefix_s3_needs_bucket_owner_access( + sagemaker_session_with_bucket_name_and_prefix, datetime_obj, caplog +): with pytest.raises(ClientError): error = ClientError( error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, operation_name="foo", ) - sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.side_effect = error + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource( + "s3" + ).meta.client.list_objects_v2.side_effect = error sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").Bucket( name=DEFAULT_BUCKET_NAME ).creation_date = None @@ -123,7 +128,10 @@ def test_default_bucket_with_prefix_s3_needs_bucket_owner_access(sagemaker_sessi error_message = "Please try again after adding appropriate access." assert error_message in caplog.text assert sagemaker_session_with_bucket_name_and_prefix._default_bucket is None - sagemaker_session_with_bucket_name_and_prefix.boto_session.resource("s3").meta.client.list_objects_v2.assert_called_once() + sagemaker_session_with_bucket_name_and_prefix.boto_session.resource( + "s3" + ).meta.client.list_objects_v2.assert_called_once() + def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog): sagemaker_session._default_bucket_name_override = "custom-bucket-override" From 8f4e57f4def8e38f2d5496fa97413c56aefe7459 Mon Sep 17 00:00:00 2001 From: Gokul A Date: Fri, 25 Apr 2025 11:59:16 -0700 Subject: [PATCH 5/7] Minor fix --- src/sagemaker/session.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 58f40b8a9c..f9a3699493 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -633,12 +633,12 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): expected_bucket_owner_id = self.account_id() if bucket.creation_date is None: self.general_bucket_check_if_user_has_permission( - bucket_name, s3, bucket, region, True, expected_bucket_owner_id + bucket_name, s3, bucket, region, True ) elif self._default_bucket_set_by_sdk: self.general_bucket_check_if_user_has_permission( - bucket_name, s3, bucket, region, False, expected_bucket_owner_id + bucket_name, s3, bucket, region, False ) self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) @@ -678,7 +678,7 @@ def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket raise def general_bucket_check_if_user_has_permission( - self, bucket_name, s3, bucket, region, bucket_creation_date_none, expected_bucket_owner_id + self, bucket_name, s3, bucket, region, bucket_creation_date_none ): """Checks if the person running has the permissions to the bucket From 9f4ea667d8442f40d57088971acd36cd74255c80 Mon Sep 17 00:00:00 2001 From: Gokul A Date: Fri, 25 Apr 2025 12:22:26 -0700 Subject: [PATCH 6/7] Codestyle fixes --- src/sagemaker/session.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f9a3699493..95383ebf51 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -632,14 +632,10 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): bucket = s3.Bucket(name=bucket_name) expected_bucket_owner_id = self.account_id() if bucket.creation_date is None: - self.general_bucket_check_if_user_has_permission( - bucket_name, s3, bucket, region, True - ) + self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True) elif self._default_bucket_set_by_sdk: - self.general_bucket_check_if_user_has_permission( - bucket_name, s3, bucket, region, False - ) + self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False) self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id): From e42fb1a3d52b049ac013974b16889a4f59cbdba9 Mon Sep 17 00:00:00 2001 From: Gokul A Date: Fri, 25 Apr 2025 14:43:02 -0700 Subject: [PATCH 7/7] Fix Unit tests --- src/sagemaker/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 95383ebf51..2cc18f6989 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -630,12 +630,12 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region): s3 = self.s3_resource bucket = s3.Bucket(name=bucket_name) - expected_bucket_owner_id = self.account_id() if bucket.creation_date is None: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, True) elif self._default_bucket_set_by_sdk: self.general_bucket_check_if_user_has_permission(bucket_name, s3, bucket, region, False) + expected_bucket_owner_id = self.account_id() self.expected_bucket_owner_id_bucket_check(bucket_name, s3, expected_bucket_owner_id) def expected_bucket_owner_id_bucket_check(self, bucket_name, s3, expected_bucket_owner_id):