Skip to content

feature: allow setting the default bucket in Session #1168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def __init__(self, boto_session=None):
if platform.system() == "Windows":
logger.warning("Windows Support for Local Mode is Experimental")

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
"""Initialize this Local SageMaker Session.

Args:
Expand Down Expand Up @@ -413,6 +413,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

self.config = yaml.load(open(sagemaker_config_file, "r"))

self._default_bucket = None
self._desired_default_bucket_name = default_bucket

def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
"""

Expand Down
37 changes: 30 additions & 7 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods
bucket based on a naming convention which includes the current AWS account ID.
"""

def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
def __init__(
self,
boto_session=None,
sagemaker_client=None,
sagemaker_runtime_client=None,
default_bucket=None,
):
"""Initialize a SageMaker ``Session``.

Args:
Expand All @@ -91,15 +97,23 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
using this ``Session`` use this client. If not provided, one will be created using
this instance's ``boto_session``.
default_bucket (str): The default s3 bucket to be used by this session.
Ex: "sagemaker-us-west-2"

"""
self._default_bucket = None

# currently is used for local_code in local mode
self.config = None

self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
self._initialize(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=sagemaker_runtime_client,
default_bucket=default_bucket,
)

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
"""Initialize this SageMaker Session.

Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
Expand All @@ -126,6 +140,12 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

prepend_user_agent(self.sagemaker_runtime_client)

self._default_bucket = None
self._desired_default_bucket_name = default_bucket

# Create default bucket on session init to verify that desired name, if specified, is valid
self.default_bucket()

self.local_mode = False

@property
Expand Down Expand Up @@ -314,11 +334,14 @@ def default_bucket(self):
if self._default_bucket:
return self._default_bucket

default_bucket = self._desired_default_bucket_name
region = self.boto_session.region_name
account = self.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
).get_caller_identity()["Account"]
default_bucket = "sagemaker-{}-{}".format(region, account)

if not default_bucket:
account = self.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
).get_caller_identity()["Account"]
default_bucket = "sagemaker-{}-{}".format(region, account)

s3 = self.boto_session.resource("s3")
try:
Expand Down
5 changes: 4 additions & 1 deletion tests/integ/test_local_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession):
def __init__(self):
super(LocalSession, self).__init__()

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
self.boto_session = boto3.Session(region_name=DEFAULT_REGION)
if self.config is None:
self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}}
Expand All @@ -53,6 +53,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
self.local_mode = True

self._default_bucket = None
self._desired_default_bucket_name = default_bucket


@pytest.fixture(scope="module")
def mxnet_model(sagemaker_local_session, mxnet_full_version):
Expand Down
116 changes: 116 additions & 0 deletions tests/integ/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import os

import boto3
import pytest
from botocore.config import Config
from sagemaker import Session
from sagemaker.fw_registry import default_framework_uri

from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor
Expand All @@ -23,6 +26,35 @@
from tests.integ.kms_utils import get_or_create_kms_key

ROLE = "SageMakerRole"
DEFAULT_REGION = "us-west-2"
CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket"


@pytest.fixture(scope="module")
def sagemaker_session_with_custom_bucket(
boto_config, sagemaker_client_config, sagemaker_runtime_config
):
boto_session = (
boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
)
sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
sagemaker_client = (
boto_session.client("sagemaker", **sagemaker_client_config)
if sagemaker_client_config
else None
)
runtime_client = (
boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
if sagemaker_runtime_config
else None
)

return Session(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=runtime_client,
default_bucket=CUSTOM_BUCKET_PATH,
)


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -170,6 +202,90 @@ def test_sklearn_with_customizations(
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_sklearn_with_custom_default_bucket(
sagemaker_session_with_custom_bucket,
image_uri,
sklearn_full_version,
cpu_instance_type,
output_kms_key,
):

input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")

sklearn_processor = SKLearnProcessor(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO: since this functionality lives in the parent Processor and propagates down to child classes, the test should use the Processor class rather than SKLearnProcessor. (If we validate using Processor, we validate for all child classes, but if we validate using SKLearnProcessor, we don't validate for parent classes.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see it the other way around.
I prefer to use the implementation in this case (SKLearnProcessor), because it allows us to validate the superclass inherited methods (normalization of inputs) AS WELL as the custom implementation logic (automatic code upload).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ ah, I realize I should've said ScriptProcessor, not Processor -- if another child class of ScriptProcessor is added, and if the integ tests use ScriptProcessor, automatic code upload would be tested for that new class by dint of it being tested with ScriptProcessor. (For example: what if SKLearnProcessor has some code added to it that lets this test pass just for SKLearnProcessor?)

framework_version=sklearn_full_version,
role=ROLE,
command=["python3"],
instance_type=cpu_instance_type,
instance_count=1,
volume_size_in_gb=100,
volume_kms_key=None,
output_kms_key=output_kms_key,
max_runtime_in_seconds=3600,
base_job_name="test-sklearn-with-customizations",
env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
sagemaker_session=sagemaker_session_with_custom_bucket,
)

sklearn_processor.run(
code=os.path.join(DATA_DIR, "dummy_script.py"),
inputs=[
ProcessingInput(
source=input_file_path,
destination="/opt/ml/processing/input/container/path/",
input_name="dummy_input",
s3_data_type="S3Prefix",
s3_input_mode="File",
s3_data_distribution_type="FullyReplicated",
s3_compression_type="None",
)
],
outputs=[
ProcessingOutput(
source="/opt/ml/processing/output/container/path/",
output_name="dummy_output",
s3_upload_mode="EndOfJob",
)
],
arguments=["-v"],
wait=True,
logs=True,
)

job_description = sklearn_processor.latest_job.describe()

assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]

assert job_description["ProcessingInputs"][1]["InputName"] == "code"
assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]

assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")

assert job_description["ProcessingJobStatus"] == "Completed"

assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"

assert job_description["ProcessingResources"] == {
"ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100}
}

assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
"python3",
"/opt/ml/processing/input/code/dummy_script.py",
]
assert job_description["AppSpecification"]["ImageUri"] == image_uri

assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}

assert ROLE in job_description["RoleArn"]

assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_sklearn_with_no_inputs_or_outputs(
sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type
):
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_create_deploy_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session", region_name=REGION)
client_mock = Mock()
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
boto_mock.client.return_value = client_mock
ims = sagemaker.Session(boto_session=boto_mock)
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
return ims
Expand Down
49 changes: 39 additions & 10 deletions tests/unit/test_default_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session", region_name=REGION)
client_mock = Mock()
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
boto_mock.client.return_value = client_mock
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}
ims = sagemaker.Session(boto_session=boto_mock)
return ims
Expand All @@ -48,11 +55,13 @@ def test_default_already_cached(sagemaker_session):
existing_default = "mydefaultbucket"
sagemaker_session._default_bucket = existing_default

before_create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls

bucket_name = sagemaker_session.default_bucket()

create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
after_create_calls = sagemaker_session.boto_session.resource().create_bucket.mock_calls
assert bucket_name == existing_default
assert create_calls == []
assert before_create_calls == after_create_calls


def test_default_bucket_exists(sagemaker_session):
Expand All @@ -78,22 +87,42 @@ def test_concurrent_bucket_modification(sagemaker_session):
assert bucket_name == DEFAULT_BUCKET_NAME


def test_bucket_creation_client_error(sagemaker_session):
def test_bucket_creation_client_error():
with pytest.raises(ClientError):
boto_mock = Mock(name="boto_session", region_name=REGION)
client_mock = Mock()
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
boto_mock.client.return_value = client_mock
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}

error = ClientError(
error_response={"Error": {"Code": "SomethingWrong", "Message": "message"}},
operation_name="foo",
)
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
boto_mock.resource().create_bucket.side_effect = error

sagemaker_session.default_bucket()
assert sagemaker_session._default_bucket is None
session = sagemaker.Session(boto_session=boto_mock)
assert session._default_bucket is None


def test_bucket_creation_other_error(sagemaker_session):
def test_bucket_creation_other_error():
with pytest.raises(RuntimeError):
boto_mock = Mock(name="boto_session", region_name=REGION)
client_mock = Mock()
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
boto_mock.client.return_value = client_mock
boto_mock.client("sts").get_caller_identity.return_value = {"Account": ACCOUNT_ID}

error = RuntimeError()
sagemaker_session.boto_session.resource().create_bucket.side_effect = error
boto_mock.resource().create_bucket.side_effect = error

sagemaker_session.default_bucket()
assert sagemaker_session._default_bucket is None
session = sagemaker.Session(boto_session=boto_mock)
assert session._default_bucket is None
7 changes: 7 additions & 0 deletions tests/unit/test_endpoint_from_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session", region_name=REGION)
client_mock = Mock()
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
boto_mock.client.return_value = client_mock
ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock)
ims.sagemaker_client.describe_training_job = Mock(
name="describe_training_job", return_value=TRAINING_JOB_RESPONSE
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_endpoint_from_model_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session", region_name=REGION)
client_mock = Mock()
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
boto_mock.client.return_value = client_mock
ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock)
ims.sagemaker_client.describe_model = Mock(
name="describe_model", side_effect=_raise_does_not_exist_client_error
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/test_exception_on_bad_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ def get_sagemaker_session(returns_status):
client_mock.describe_model_package = MagicMock(
return_value={"ModelPackageStatus": returns_status}
)
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
client_mock.describe_endpoint = MagicMock(return_value={"EndpointStatus": returns_status})
boto_mock.client.return_value = client_mock
ims = sagemaker.Session(boto_session=boto_mock, sagemaker_client=client_mock)
ims.expand_role = Mock(return_value=EXPANDED_ROLE)
return ims
Expand Down
7 changes: 7 additions & 0 deletions tests/unit/test_local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,5 +473,12 @@ def test_file_input_content_type():

def test_local_session_is_set_to_local_mode():
boto_session = Mock(region_name="us-west-2")
client_mock = Mock()
client_mock.get_caller_identity.return_value = {
"UserId": "mock_user_id",
"Account": "012345678910",
"Arn": "arn:aws:iam::012345678910:user/mock-user",
}
boto_session.client.return_value = client_mock
local_session = sagemaker.local.local_session.LocalSession(boto_session=boto_session)
assert local_session.local_mode
Loading