Skip to content

Commit e980291

Browse files
authored
Revert "feature: allow setting the default bucket in Session (#1168)" (#1175)
This reverts commit 98fa76d.
1 parent 3b4ba23 commit e980291

12 files changed

+19
-267
lines changed

src/sagemaker/local/local_session.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ def __init__(self, boto_session=None):
379379
if platform.system() == "Windows":
380380
logger.warning("Windows Support for Local Mode is Experimental")
381381

382-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
382+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
383383
"""Initialize this Local SageMaker Session.
384384
385385
Args:
@@ -413,9 +413,6 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client,
413413

414414
self.config = yaml.load(open(sagemaker_config_file, "r"))
415415

416-
self._default_bucket = None
417-
self._desired_default_bucket_name = default_bucket
418-
419416
def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
420417
"""
421418

src/sagemaker/session.py

+7-30
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,7 @@ class Session(object): # pylint: disable=too-many-public-methods
7676
bucket based on a naming convention which includes the current AWS account ID.
7777
"""
7878

79-
def __init__(
80-
self,
81-
boto_session=None,
82-
sagemaker_client=None,
83-
sagemaker_runtime_client=None,
84-
default_bucket=None,
85-
):
79+
def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
8680
"""Initialize a SageMaker ``Session``.
8781
8882
Args:
@@ -97,23 +91,15 @@ def __init__(
9791
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
9892
using this ``Session`` use this client. If not provided, one will be created using
9993
this instance's ``boto_session``.
100-
default_bucket (str): The default s3 bucket to be used by this session.
101-
Ex: "sagemaker-us-west-2"
102-
10394
"""
10495
self._default_bucket = None
10596

10697
# currently is used for local_code in local mode
10798
self.config = None
10899

109-
self._initialize(
110-
boto_session=boto_session,
111-
sagemaker_client=sagemaker_client,
112-
sagemaker_runtime_client=sagemaker_runtime_client,
113-
default_bucket=default_bucket,
114-
)
100+
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
115101

116-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
102+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
117103
"""Initialize this SageMaker Session.
118104
119105
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
@@ -140,12 +126,6 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client,
140126

141127
prepend_user_agent(self.sagemaker_runtime_client)
142128

143-
self._default_bucket = None
144-
self._desired_default_bucket_name = default_bucket
145-
146-
# Create default bucket on session init to verify that desired name, if specified, is valid
147-
self.default_bucket()
148-
149129
self.local_mode = False
150130

151131
@property
@@ -334,14 +314,11 @@ def default_bucket(self):
334314
if self._default_bucket:
335315
return self._default_bucket
336316

337-
default_bucket = self._desired_default_bucket_name
338317
region = self.boto_session.region_name
339-
340-
if not default_bucket:
341-
account = self.boto_session.client(
342-
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
343-
).get_caller_identity()["Account"]
344-
default_bucket = "sagemaker-{}-{}".format(region, account)
318+
account = self.boto_session.client(
319+
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
320+
).get_caller_identity()["Account"]
321+
default_bucket = "sagemaker-{}-{}".format(region, account)
345322

346323
s3 = self.boto_session.resource("s3")
347324
try:

tests/integ/test_local_mode.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession):
4343
def __init__(self):
4444
super(LocalSession, self).__init__()
4545

46-
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
46+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
4747
self.boto_session = boto3.Session(region_name=DEFAULT_REGION)
4848
if self.config is None:
4949
self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}}
@@ -53,9 +53,6 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client,
5353
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
5454
self.local_mode = True
5555

56-
self._default_bucket = None
57-
self._desired_default_bucket_name = default_bucket
58-
5956

6057
@pytest.fixture(scope="module")
6158
def mxnet_model(sagemaker_local_session, mxnet_full_version):

tests/integ/test_processing.py

-116
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515
import os
1616

17-
import boto3
1817
import pytest
19-
from botocore.config import Config
20-
from sagemaker import Session
2118
from sagemaker.fw_registry import default_framework_uri
2219

2320
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor
@@ -26,35 +23,6 @@
2623
from tests.integ.kms_utils import get_or_create_kms_key
2724

2825
ROLE = "SageMakerRole"
29-
DEFAULT_REGION = "us-west-2"
30-
CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket"
31-
32-
33-
@pytest.fixture(scope="module")
34-
def sagemaker_session_with_custom_bucket(
35-
boto_config, sagemaker_client_config, sagemaker_runtime_config
36-
):
37-
boto_session = (
38-
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
39-
)
40-
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
41-
sagemaker_client = (
42-
boto_session.client("sagemaker", **sagemaker_client_config)
43-
if sagemaker_client_config
44-
else None
45-
)
46-
runtime_client = (
47-
boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
48-
if sagemaker_runtime_config
49-
else None
50-
)
51-
52-
return Session(
53-
boto_session=boto_session,
54-
sagemaker_client=sagemaker_client,
55-
sagemaker_runtime_client=runtime_client,
56-
default_bucket=CUSTOM_BUCKET_PATH,
57-
)
5826

5927

6028
@pytest.fixture(scope="module")
@@ -202,90 +170,6 @@ def test_sklearn_with_customizations(
202170
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
203171

204172

205-
def test_sklearn_with_custom_default_bucket(
206-
sagemaker_session_with_custom_bucket,
207-
image_uri,
208-
sklearn_full_version,
209-
cpu_instance_type,
210-
output_kms_key,
211-
):
212-
213-
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
214-
215-
sklearn_processor = SKLearnProcessor(
216-
framework_version=sklearn_full_version,
217-
role=ROLE,
218-
command=["python3"],
219-
instance_type=cpu_instance_type,
220-
instance_count=1,
221-
volume_size_in_gb=100,
222-
volume_kms_key=None,
223-
output_kms_key=output_kms_key,
224-
max_runtime_in_seconds=3600,
225-
base_job_name="test-sklearn-with-customizations",
226-
env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
227-
tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
228-
sagemaker_session=sagemaker_session_with_custom_bucket,
229-
)
230-
231-
sklearn_processor.run(
232-
code=os.path.join(DATA_DIR, "dummy_script.py"),
233-
inputs=[
234-
ProcessingInput(
235-
source=input_file_path,
236-
destination="/opt/ml/processing/input/container/path/",
237-
input_name="dummy_input",
238-
s3_data_type="S3Prefix",
239-
s3_input_mode="File",
240-
s3_data_distribution_type="FullyReplicated",
241-
s3_compression_type="None",
242-
)
243-
],
244-
outputs=[
245-
ProcessingOutput(
246-
source="/opt/ml/processing/output/container/path/",
247-
output_name="dummy_output",
248-
s3_upload_mode="EndOfJob",
249-
)
250-
],
251-
arguments=["-v"],
252-
wait=True,
253-
logs=True,
254-
)
255-
256-
job_description = sklearn_processor.latest_job.describe()
257-
258-
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
259-
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
260-
261-
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
262-
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]
263-
264-
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")
265-
266-
assert job_description["ProcessingJobStatus"] == "Completed"
267-
268-
assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
269-
assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"
270-
271-
assert job_description["ProcessingResources"] == {
272-
"ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100}
273-
}
274-
275-
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
276-
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
277-
"python3",
278-
"/opt/ml/processing/input/code/dummy_script.py",
279-
]
280-
assert job_description["AppSpecification"]["ImageUri"] == image_uri
281-
282-
assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}
283-
284-
assert ROLE in job_description["RoleArn"]
285-
286-
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
287-
288-
289173
def test_sklearn_with_no_inputs_or_outputs(
290174
sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type
291175
):

tests/unit/test_create_deploy_entities.py

-7
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,6 @@
3434
@pytest.fixture()
3535
def sagemaker_session():
3636
boto_mock = Mock(name="boto_session", region_name=REGION)
37-
client_mock = Mock()
38-
client_mock.get_caller_identity.return_value = {
39-
"UserId": "mock_user_id",
40-
"Account": "012345678910",
41-
"Arn": "arn:aws:iam::012345678910:user/mock-user",
42-
}
43-
boto_mock.client.return_value = client_mock
4437
ims = sagemaker.Session(boto_session=boto_mock)
4538
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
4639
return ims

tests/unit/test_default_bucket.py

+10-39
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,6 @@
2525
@pytest.fixture()
2626
def sagemaker_session():
2727
boto_mock = Mock(name="boto_session", region_name=REGION)
28-
client_mock = Mock()
29-
client_mock.get_caller_identity.return_value = {
30-
"UserId": "mock_user_id",
31-
"Account": "012345678910",
32-
"Arn": "arn:aws:iam::012345678910:user/mock-user",
33-
}
34-
boto_mock.client.return_value = client_mock
3528
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
3629
ims = sagemaker.Session(boto_session=boto_mock)
3730
return ims
@@ -55,13 +48,11 @@ def test_default_already_cached(sagemaker_session):
5548
existing_default = "mydefaultbucket"
5649
sagemaker_session._default_bucket = existing_default
5750

58-
before_create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
59-
6051
bucket_name = sagemaker_session.default_bucket()
6152

62-
after_create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
53+
create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
6354
assert bucket_name == existing_default
64-
assert before_create_calls == after_create_calls
55+
assert create_calls == []
6556

6657

6758
def test_default_bucket_exists(sagemaker_session):
@@ -87,42 +78,22 @@ def test_concurrent_bucket_modification(sagemaker_session):
8778
assert bucket_name == DEFAULT_BUCKET_NAME
8879

8980

90-
def test_bucket_creation_client_error():
81+
def test_bucket_creation_client_error(sagemaker_session):
9182
with pytest.raises(ClientError):
92-
boto_mock = Mock(name="boto_session", region_name=REGION)
93-
client_mock = Mock()
94-
client_mock.get_caller_identity.return_value = {
95-
"UserId": "mock_user_id",
96-
"Account": "012345678910",
97-
"Arn": "arn:aws:iam::012345678910:user/mock-user",
98-
}
99-
boto_mock.client.return_value = client_mock
100-
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
101-
10283
error = ClientError(
10384
error_response={"Error": {"Code": "SomethingWrong", "Message": "message"}},
10485
operation_name="foo",
10586
)
106-
boto_mock.resource().create_bucket.side_effect = error
87+
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
10788

108-
session = sagemaker.Session(boto_session=boto_mock)
109-
assert session._default_bucket is None
89+
sagemaker_session.default_bucket()
90+
assert sagemaker_session._default_bucket is None
11091

11192

112-
def test_bucket_creation_other_error():
93+
def test_bucket_creation_other_error(sagemaker_session):
11394
with pytest.raises(RuntimeError):
114-
boto_mock = Mock(name="boto_session", region_name=REGION)
115-
client_mock = Mock()
116-
client_mock.get_caller_identity.return_value = {
117-
"UserId": "mock_user_id",
118-
"Account": "012345678910",
119-
"Arn": "arn:aws:iam::012345678910:user/mock-user",
120-
}
121-
boto_mock.client.return_value = client_mock
122-
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
123-
12495
error = RuntimeError()
125-
boto_mock.resource().create_bucket.side_effect = error
96+
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
12697

127-
session = sagemaker.Session(boto_session=boto_mock)
128-
assert session._default_bucket is None
98+
sagemaker_session.default_bucket()
99+
assert sagemaker_session._default_bucket is None

tests/unit/test_endpoint_from_job.py

-7
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,6 @@
4343
@pytest.fixture()
4444
def sagemaker_session():
4545
boto_mock = Mock(name="boto_session", region_name=REGION)
46-
client_mock = Mock()
47-
client_mock.get_caller_identity.return_value = {
48-
"UserId": "mock_user_id",
49-
"Account": "012345678910",
50-
"Arn": "arn:aws:iam::012345678910:user/mock-user",
51-
}
52-
boto_mock.client.return_value = client_mock
5346
ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock)
5447
ims.sagemaker_client.describe_training_job = Mock(
5548
name="describe_training_job", return_value=TRAINING_JOB_RESPONSE

tests/unit/test_endpoint_from_model_data.py

-7
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,6 @@
3636
@pytest.fixture()
3737
def sagemaker_session():
3838
boto_mock = Mock(name="boto_session", region_name=REGION)
39-
client_mock = Mock()
40-
client_mock.get_caller_identity.return_value = {
41-
"UserId": "mock_user_id",
42-
"Account": "012345678910",
43-
"Arn": "arn:aws:iam::012345678910:user/mock-user",
44-
}
45-
boto_mock.client.return_value = client_mock
4639
ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock)
4740
ims.sagemaker_client.describe_model = Mock(
4841
name="describe_model", side_effect=_raise_does_not_exist_client_error

tests/unit/test_exception_on_bad_status.py

-6
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,7 @@ def get_sagemaker_session(returns_status):
2929
client_mock.describe_model_package = MagicMock(
3030
return_value={"ModelPackageStatus": returns_status}
3131
)
32-
client_mock.get_caller_identity.return_value = {
33-
"UserId": "mock_user_id",
34-
"Account": "012345678910",
35-
"Arn": "arn:aws:iam::012345678910:user/mock-user",
36-
}
3732
client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status})
38-
boto_mock.client.return_value = client_mock
3933
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock)
4034
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
4135
return ims

tests/unit/test_local_session.py

-7
Original file line numberDiff line numberDiff line change
@@ -473,12 +473,5 @@ def test_file_input_content_type():
473473

474474
def test_local_session_is_set_to_local_mode():
475475
boto_session = Mock(region_name="us-west-2")
476-
client_mock = Mock()
477-
client_mock.get_caller_identity.return_value = {
478-
"UserId": "mock_user_id",
479-
"Account": "012345678910",
480-
"Arn": "arn:aws:iam::012345678910:user/mock-user",
481-
}
482-
boto_session.client.return_value = client_mock
483476
local_session = sagemaker.local.local_session.LocalSession(boto_session=boto_session)
484477
assert local_session.local_mode

0 commit comments

Comments
 (0)