Skip to content

feature: allow setting the default bucket in Session #1176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 16, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def __init__(self, boto_session=None):
if platform.system() == "Windows":
logger.warning("Windows Support for Local Mode is Experimental")

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
"""Initialize this Local SageMaker Session.

Args:
Expand Down Expand Up @@ -413,6 +413,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

self.config = yaml.load(open(sagemaker_config_file, "r"))

self._default_bucket = None
self._desired_default_bucket_name = default_bucket

def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
"""

Expand Down
34 changes: 27 additions & 7 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ class Session(object): # pylint: disable=too-many-public-methods
bucket based on a naming convention which includes the current AWS account ID.
"""

def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_client=None):
def __init__(
self,
boto_session=None,
sagemaker_client=None,
sagemaker_runtime_client=None,
default_bucket=None,
):
"""Initialize a SageMaker ``Session``.

Args:
Expand All @@ -91,15 +97,23 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
``InvokeEndpoint`` calls to Amazon SageMaker (default: None). Predictors created
using this ``Session`` use this client. If not provided, one will be created using
this instance's ``boto_session``.
default_bucket (str): The default s3 bucket to be used by this session.
Ex: "sagemaker-us-west-2"

"""
self._default_bucket = None

# currently is used for local_code in local mode
self.config = None

self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
self._initialize(
boto_session=boto_session,
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=sagemaker_runtime_client,
default_bucket=default_bucket,
)

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
"""Initialize this SageMaker Session.

Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
Expand All @@ -126,6 +140,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):

prepend_user_agent(self.sagemaker_runtime_client)

self._default_bucket = None
self._desired_default_bucket_name = default_bucket

self.local_mode = False

@property
Expand Down Expand Up @@ -314,11 +331,14 @@ def default_bucket(self):
if self._default_bucket:
return self._default_bucket

default_bucket = self._desired_default_bucket_name
region = self.boto_session.region_name
account = self.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
).get_caller_identity()["Account"]
default_bucket = "sagemaker-{}-{}".format(region, account)

if not default_bucket:
account = self.boto_session.client(
"sts", region_name=region, endpoint_url=sts_regional_endpoint(region)
).get_caller_identity()["Account"]
default_bucket = "sagemaker-{}-{}".format(region, account)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I'd move l. 334 to be right before the if statement because it's clearer that it's basically just an if/else there. Or just make it an if/else:

if self._desired_default_bucket_name
    default_bucket = self._desired_default_bucket_name
else
    ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved the line closer, but I like to avoid if/elses when possible.
I think it's more readable to have a default case get overridden under specific circumstances. Also guarantees that the variable always gets set.


s3 = self.boto_session.resource("s3")
try:
Expand Down
5 changes: 4 additions & 1 deletion tests/integ/test_local_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class LocalNoS3Session(LocalSession):
def __init__(self):
    # Deliberately starts the MRO lookup *after* LocalSession, so
    # LocalSession.__init__ is skipped and the next base class's __init__
    # runs instead. This lets the test double reuse the base Session setup
    # while avoiding LocalSession's own initialization side effects.
    super(LocalSession, self).__init__()

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client, default_bucket):
self.boto_session = boto3.Session(region_name=DEFAULT_REGION)
if self.config is None:
self.config = {"local": {"local_code": True, "region_name": DEFAULT_REGION}}
Expand All @@ -53,6 +53,9 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
self.local_mode = True

self._default_bucket = None
self._desired_default_bucket_name = default_bucket


@pytest.fixture(scope="module")
def mxnet_model(sagemaker_local_session, mxnet_full_version):
Expand Down
184 changes: 184 additions & 0 deletions tests/integ/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

import os

import boto3
import pytest
from botocore.config import Config
from sagemaker import Session
from sagemaker.fw_registry import default_framework_uri

from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor, Processor
Expand All @@ -23,6 +26,35 @@
from tests.integ.kms_utils import get_or_create_kms_key

ROLE = "SageMakerRole"
DEFAULT_REGION = "us-west-2"
CUSTOM_BUCKET_PATH = "sagemaker-custom-bucket"


@pytest.fixture(scope="module")
def sagemaker_session_with_custom_bucket(
    boto_config, sagemaker_client_config, sagemaker_runtime_config
):
    """Return a Session whose default bucket is overridden to CUSTOM_BUCKET_PATH.

    Mirrors the standard session fixture, but passes ``default_bucket`` so
    tests can verify that artifacts land in the custom bucket.
    """
    if boto_config:
        boto_session = boto3.Session(**boto_config)
    else:
        boto_session = boto3.Session(region_name=DEFAULT_REGION)

    # Retry throttled SageMaker API calls up to 10 times.
    sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))

    sagemaker_client = None
    if sagemaker_client_config:
        sagemaker_client = boto_session.client("sagemaker", **sagemaker_client_config)

    runtime_client = None
    if sagemaker_runtime_config:
        runtime_client = boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)

    return Session(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_runtime_client=runtime_client,
        default_bucket=CUSTOM_BUCKET_PATH,
    )


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -170,6 +202,89 @@ def test_sklearn_with_customizations(
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_sklearn_with_custom_default_bucket(
    sagemaker_session_with_custom_bucket,
    image_uri,
    sklearn_full_version,
    cpu_instance_type,
    output_kms_key,
):
    """Run an SKLearnProcessor against a Session with a custom default bucket.

    Verifies that both the user-supplied input and the injected code input are
    staged under CUSTOM_BUCKET_PATH instead of the account's auto-derived
    SageMaker bucket.
    """
    input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")

    sklearn_processor = SKLearnProcessor(
        framework_version=sklearn_full_version,
        role=ROLE,
        command=["python3"],
        instance_type=cpu_instance_type,
        instance_count=1,
        volume_size_in_gb=100,
        volume_kms_key=None,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=3600,
        # Fixed copy-paste: was "test-sklearn-with-customizations", which made
        # this test's jobs indistinguishable from test_sklearn_with_customizations'.
        base_job_name="test-sklearn-with-custom-bucket",
        env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
        tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
        sagemaker_session=sagemaker_session_with_custom_bucket,
    )

    sklearn_processor.run(
        code=os.path.join(DATA_DIR, "dummy_script.py"),
        inputs=[
            ProcessingInput(
                source=input_file_path,
                destination="/opt/ml/processing/input/container/path/",
                input_name="dummy_input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
                s3_data_distribution_type="FullyReplicated",
                s3_compression_type="None",
            )
        ],
        outputs=[
            ProcessingOutput(
                source="/opt/ml/processing/output/container/path/",
                output_name="dummy_output",
                s3_upload_mode="EndOfJob",
            )
        ],
        arguments=["-v"],
        wait=True,
        logs=True,
    )

    job_description = sklearn_processor.latest_job.describe()

    # Both inputs (user data and uploaded code) must live in the custom bucket.
    assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]

    assert job_description["ProcessingInputs"][1]["InputName"] == "code"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]

    assert job_description["ProcessingJobName"].startswith("test-sklearn-with-custom-bucket")

    assert job_description["ProcessingJobStatus"] == "Completed"

    assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
    assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"

    assert job_description["ProcessingResources"] == {
        "ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100}
    }

    assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
    assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
        "python3",
        "/opt/ml/processing/input/code/dummy_script.py",
    ]
    assert job_description["AppSpecification"]["ImageUri"] == image_uri

    assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}

    assert ROLE in job_description["RoleArn"]

    assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_sklearn_with_no_inputs_or_outputs(
sagemaker_session, image_uri, sklearn_full_version, cpu_instance_type
):
Expand Down Expand Up @@ -405,3 +520,72 @@ def test_processor(sagemaker_session, image_uri, cpu_instance_type, output_kms_k
assert ROLE in job_description["RoleArn"]

assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}


def test_processor_with_custom_bucket(
    sagemaker_session_with_custom_bucket, image_uri, cpu_instance_type, output_kms_key
):
    """Run a bare Processor against a Session with a custom default bucket.

    Verifies the code input is staged under CUSTOM_BUCKET_PATH instead of the
    account's auto-derived SageMaker bucket.
    """
    script_path = os.path.join(DATA_DIR, "dummy_script.py")

    processor = Processor(
        role=ROLE,
        image_uri=image_uri,
        instance_count=1,
        instance_type=cpu_instance_type,
        entrypoint=["python3", "/opt/ml/processing/input/code/dummy_script.py"],
        volume_size_in_gb=100,
        volume_kms_key=None,
        output_kms_key=output_kms_key,
        max_runtime_in_seconds=3600,
        # Fixed copy-paste: was "test-processor", which collided with the job
        # names produced by test_processor and broke disambiguation in CI.
        base_job_name="test-processor-with-custom-bucket",
        env={"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"},
        tags=[{"Key": "dummy-tag", "Value": "dummy-tag-value"}],
        sagemaker_session=sagemaker_session_with_custom_bucket,
    )

    processor.run(
        inputs=[
            ProcessingInput(
                source=script_path, destination="/opt/ml/processing/input/code/", input_name="code"
            )
        ],
        outputs=[
            ProcessingOutput(
                source="/opt/ml/processing/output/container/path/",
                output_name="dummy_output",
                s3_upload_mode="EndOfJob",
            )
        ],
        arguments=["-v"],
        wait=True,
        logs=True,
    )

    job_description = processor.latest_job.describe()

    # The code input must have been uploaded to the custom bucket.
    assert job_description["ProcessingInputs"][0]["InputName"] == "code"
    assert CUSTOM_BUCKET_PATH in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]

    assert job_description["ProcessingJobName"].startswith("test-processor-with-custom-bucket")

    assert job_description["ProcessingJobStatus"] == "Completed"

    assert job_description["ProcessingOutputConfig"]["KmsKeyId"] == output_kms_key
    assert job_description["ProcessingOutputConfig"]["Outputs"][0]["OutputName"] == "dummy_output"

    assert job_description["ProcessingResources"] == {
        "ClusterConfig": {"InstanceCount": 1, "InstanceType": "ml.m4.xlarge", "VolumeSizeInGB": 100}
    }

    assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
    assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
        "python3",
        "/opt/ml/processing/input/code/dummy_script.py",
    ]
    assert job_description["AppSpecification"]["ImageUri"] == image_uri

    assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}

    assert ROLE in job_description["RoleArn"]

    assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 3600}
50 changes: 50 additions & 0 deletions tests/integ/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import boto3
from botocore.config import Config

from sagemaker import Session

DEFAULT_REGION = "us-west-2"
CUSTOM_BUCKET_NAME = "this-bucket-should-not-exist"


def test_sagemaker_session_does_not_create_bucket_on_init(
    sagemaker_client_config, sagemaker_runtime_config, boto_config
):
    """Constructing a Session with ``default_bucket`` must not create the bucket.

    Bucket creation is expected to happen lazily on first use of the default
    bucket, never inside ``Session.__init__``.
    """
    boto_session = (
        boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
    )
    # Retry throttled SageMaker API calls up to 10 times.
    sagemaker_client_config.setdefault("config", Config(retries=dict(max_attempts=10)))
    sagemaker_client = (
        boto_session.client("sagemaker", **sagemaker_client_config)
        if sagemaker_client_config
        else None
    )
    runtime_client = (
        boto_session.client("sagemaker-runtime", **sagemaker_runtime_config)
        if sagemaker_runtime_config
        else None
    )

    Session(
        boto_session=boto_session,
        sagemaker_client=sagemaker_client,
        sagemaker_runtime_client=runtime_client,
        default_bucket=CUSTOM_BUCKET_NAME,
    )

    # Fixed: use the session's own credentials/region for the existence check
    # instead of the ambient default chain (boto3.resource("s3")), so the
    # lookup runs against the same account the Session would create the bucket
    # in. ``creation_date`` should be None for a bucket absent from the
    # caller's account.
    s3 = boto_session.resource("s3")
    assert s3.Bucket(CUSTOM_BUCKET_NAME).creation_date is None