|
12 | 12 | # language governing permissions and limitations under the License.
|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
| 15 | +import datetime |
15 | 16 | import pytest
|
16 | 17 | from botocore.exceptions import ClientError
|
17 |
| -from mock import MagicMock, patch |
| 18 | +from mock import MagicMock |
18 | 19 | import sagemaker
|
19 | 20 |
|
20 | 21 | ACCOUNT_ID = "123"
|
21 | 22 | REGION = "us-west-2"
|
22 | 23 | DEFAULT_BUCKET_NAME = "sagemaker-{}-{}".format(REGION, ACCOUNT_ID)
|
23 | 24 |
|
24 | 25 |
|
| 26 | +@pytest.fixture |
| 27 | +def datetime_obj(): |
| 28 | + return datetime.datetime(2017, 6, 16, 15, 55, 0) |
| 29 | + |
| 30 | + |
25 | 31 | @pytest.fixture()
|
26 | 32 | def sagemaker_session():
|
27 | 33 | boto_mock = MagicMock(name="boto_session", region_name=REGION)
|
@@ -50,23 +56,53 @@ def test_default_bucket_s3_create_call(sagemaker_session):
|
50 | 56 | assert sagemaker_session._default_bucket == bucket_name
|
51 | 57 |
|
52 | 58 |
|
53 |
| -def test_default_bucket_s3_needs_access(sagemaker_session): |
54 |
| - with patch("logging.Logger.error") as mocked_error_log: |
55 |
| - with pytest.raises(ClientError): |
56 |
| - error = ClientError( |
57 |
| - error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, |
58 |
| - operation_name="foo", |
59 |
| - ) |
60 |
| - sagemaker_session.boto_session.resource( |
61 |
| - "s3" |
62 |
| - ).meta.client.head_bucket.side_effect = error |
63 |
| - sagemaker_session.default_bucket() |
64 |
| - mocked_error_log.assert_called_once_with( |
65 |
| - "Bucket %s exists, but access is forbidden. Please try again after " |
66 |
| - "adding appropriate access.", |
67 |
| - DEFAULT_BUCKET_NAME, |
68 |
| - ) |
69 |
| - assert sagemaker_session._default_bucket is None |
| 59 | +def test_default_bucket_s3_needs_access(sagemaker_session, caplog): |
| 60 | + with pytest.raises(ClientError): |
| 61 | + error = ClientError( |
| 62 | + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, |
| 63 | + operation_name="foo", |
| 64 | + ) |
| 65 | + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error |
| 66 | + sagemaker_session.default_bucket() |
| 67 | + error_message = ( |
| 68 | + " exists, but access is forbidden. Please try again after adding appropriate access." |
| 69 | + ) |
| 70 | + assert error_message in caplog.text |
| 71 | + assert sagemaker_session._default_bucket is None |
| 72 | + |
| 73 | + |
| 74 | +def test_default_bucket_s3_needs_bucket_owner_access(sagemaker_session, datetime_obj, caplog): |
| 75 | + with pytest.raises(ClientError): |
| 76 | + error = ClientError( |
| 77 | + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, |
| 78 | + operation_name="foo", |
| 79 | + ) |
| 80 | + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error |
| 81 | + # bucket exists |
| 82 | + sagemaker_session.boto_session.resource("s3").Bucket( |
| 83 | + name=DEFAULT_BUCKET_NAME |
| 84 | + ).creation_date = datetime_obj |
| 85 | + sagemaker_session.default_bucket() |
| 86 | + |
| 87 | + error_message = "This bucket cannot be configured to use as it is not owned by Account" |
| 88 | + assert error_message in caplog.text |
| 89 | + assert sagemaker_session._default_bucket is None |
| 90 | + |
| 91 | + |
| 92 | +def test_default_bucket_s3_custom_bucket_input(sagemaker_session, datetime_obj, caplog): |
| 93 | + sagemaker_session._default_bucket_name_override = "custom-bucket-override" |
| 94 | + error = ClientError( |
| 95 | + error_response={"Error": {"Code": "403", "Message": "Forbidden"}}, |
| 96 | + operation_name="foo", |
| 97 | + ) |
| 98 | + sagemaker_session.boto_session.resource("s3").meta.client.head_bucket.side_effect = error |
| 99 | + # bucket exists |
| 100 | + sagemaker_session.boto_session.resource("s3").Bucket( |
| 101 | + name=DEFAULT_BUCKET_NAME |
| 102 | + ).creation_date = datetime_obj |
| 103 | + # This should not raise ClientError as no head_bucket call is expected for custom bucket |
| 104 | + sagemaker_session.default_bucket() |
| 105 | + assert sagemaker_session._default_bucket == "custom-bucket-override" |
70 | 106 |
|
71 | 107 |
|
72 | 108 | def test_default_already_cached(sagemaker_session):
|
|
0 commit comments