Skip to content

feature: Allow setting S3 endpoint URL for Local Session #1359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Mar 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1713,6 +1713,7 @@ def _stage_user_code_in_s3(self):
directory=self.source_dir,
dependencies=self.dependencies,
kms_key=kms_key,
s3_resource=self.sagemaker_session.s3_resource,
)

def _model_source_dir(self):
Expand Down
19 changes: 17 additions & 2 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,14 @@ def validate_source_dir(script, directory):


def tar_and_upload_dir(
session, bucket, s3_key_prefix, script, directory=None, dependencies=None, kms_key=None
session,
bucket,
s3_key_prefix,
script,
directory=None,
dependencies=None,
kms_key=None,
s3_resource=None,
):
"""Package source files and upload a compress tar file to S3. The S3
location will be ``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
Expand All @@ -331,6 +338,9 @@ def tar_and_upload_dir(
copied into /opt/ml/lib
kms_key (str): Optional. KMS key ID used to upload objects to the bucket
(default: None).
s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource
for S3 connections, can be used to customize the configuration,
e.g. set the endpoint URL (default: None).
Returns:
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
script name.
Expand All @@ -354,7 +364,12 @@ def tar_and_upload_dir(
else:
extra_args = None

session.resource("s3").Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
if s3_resource is None:
s3_resource = session.resource("s3", region_name=session.region_name)
else:
print("Using provided s3_resource")

s3_resource.Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
finally:
shutil.rmtree(tmp)

Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
# Environment variables to be set during training
REGION_ENV_NAME = "AWS_REGION"
TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME"
S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -139,6 +140,11 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
REGION_ENV_NAME: self.sagemaker_session.boto_region_name,
TRAINING_JOB_NAME_ENV_NAME: job_name,
}
if self.sagemaker_session.s3_resource is not None:
training_env_vars[
S3_ENDPOINT_URL_ENV_NAME
] = self.sagemaker_session.s3_resource.meta.client._endpoint.host

compose_data = self._generate_compose_file(
"train", additional_volumes=volumes, additional_env_vars=training_env_vars
)
Expand Down Expand Up @@ -206,6 +212,7 @@ def serve(self, model_dir, environment):
"serve", additional_env_vars=environment, additional_volumes=volumes
)
compose_command = self._compose()

self.container = _HostingContainer(compose_command)
self.container.start()

Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def create_training_job(
)
training_job = _LocalTrainingJob(container)
hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {}
logger.info("Starting training job")
training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName)

LocalSagemakerClient._training_jobs[TrainingJobName] = training_job
Expand Down Expand Up @@ -377,7 +378,9 @@ def invoke_endpoint(
class LocalSession(Session):
"""Placeholder docstring"""

def __init__(self, boto_session=None):
def __init__(self, boto_session=None, s3_endpoint_url=None):
self.s3_endpoint_url = s3_endpoint_url

super(LocalSession, self).__init__(boto_session)

if platform.system() == "Windows":
Expand Down Expand Up @@ -407,6 +410,10 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
self.local_mode = True

if self.s3_endpoint_url is not None:
self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url)
self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url)

sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
if os.path.exists(sagemaker_config_file):
try:
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,13 @@ def __init__(
self.model = model
self.container_mode = MULTI_MODEL_CONTAINER_MODE
self.sagemaker_session = sagemaker_session or Session()
self.s3_client = self.sagemaker_session.boto_session.client("s3")

if self.sagemaker_session.s3_client is None:
self.s3_client = self.sagemaker_session.boto_session.client(
"s3", region_name=self.sagemaker_session.boto_session.region_name
)
else:
self.s3_client = self.sagemaker_session.s3_client

# Set the ``Model`` parameters if the model parameter is not specified
if not self.model:
Expand Down
39 changes: 30 additions & 9 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def __init__(
"""
self._default_bucket = None
self._default_bucket_name_override = default_bucket

# currently used for local_code in local mode
self.s3_resource = None
self.s3_client = None
self.config = None

self._initialize(
Expand Down Expand Up @@ -199,7 +199,10 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None):
key_suffix = name

bucket = bucket or self.default_bucket()
s3 = self.boto_session.resource("s3")
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource

for local_path, s3_key in files:
s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args)
Expand Down Expand Up @@ -227,7 +230,11 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
str: The S3 URI of the uploaded file.
The URI format is: ``s3://{bucket name}/{key}``.
"""
s3 = self.boto_session.resource("s3")
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource

s3_object = s3.Object(bucket_name=bucket, key=key)

if kms_key is not None:
Expand All @@ -254,7 +261,10 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):

"""
# Initialize the S3 client.
s3 = self.boto_session.client("s3")
if self.s3_client is None:
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_client

# Initialize the variables used to loop through the contents of the S3 bucket.
keys = []
Expand Down Expand Up @@ -299,7 +309,10 @@ def read_s3_file(self, bucket, key_prefix):
str: The body of the s3 file as a string.

"""
s3 = self.boto_session.client("s3")
if self.s3_client is None:
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_client

# Explicitly passing a None kms_key to boto3 throws a validation error.
s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)
Expand All @@ -317,7 +330,10 @@ def list_s3_files(self, bucket, key_prefix):
[str]: The list of files at the S3 path.

"""
s3 = self.boto_session.resource("s3")
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
else:
s3 = self.s3_resource

s3_bucket = s3.Bucket(name=bucket)
s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
Expand All @@ -330,6 +346,7 @@ def default_bucket(self):
str: The name of the default bucket, which is of the form:
``sagemaker-{region}-{AWS account ID}``.
"""

if self._default_bucket:
return self._default_bucket

Expand Down Expand Up @@ -364,10 +381,14 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
already being created, no exception is raised.

"""
bucket = self.boto_session.resource("s3", region_name=region).Bucket(name=bucket_name)
if self.s3_resource is None:
s3 = self.boto_session.resource("s3", region_name=region)
else:
s3 = self.s3_resource

bucket = s3.Bucket(name=bucket_name)
if bucket.creation_date is None:
try:
s3 = self.boto_session.resource("s3", region_name=region)
if region == "us-east-1":
# 'us-east-1' cannot be specified because it is the default region:
# https://github.com/boto/boto3/issues/125
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def sagemaker_session():
boto_region_name=REGION,
config=None,
local_mode=False,
s3_resource=None,
s3_client=None,
)
session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
session._default_bucket = BUCKET_NAME
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ def sagemaker_session():
boto_region_name=REGION,
config=None,
local_mode=False,
s3_resource=None,
s3_client=None,
)

describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ def sagemaker_session():
boto_region_name=REGION,
config=None,
local_mode=False,
s3_client=None,
s3_resource=None,
)
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
sms.sagemaker_client.describe_training_job = Mock(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_fm.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def sagemaker_session():
region_name=REGION,
config=None,
local_mode=False,
s3_client=False,
s3_resource=False,
)
sms.boto_region_name = REGION
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def cd(path):
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session", region_name=REGION)
session_mock = Mock(name="sagemaker_session", boto_session=boto_mock)
session_mock = Mock(
name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None
)
session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
session_mock.expand_role = Mock(name="expand_role", return_value=ROLE)
session_mock.sagemaker_client.describe_training_job = Mock(
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def estimator(sagemaker_session):
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session")
mock_session = Mock(name="sagemaker_session", boto_session=boto_mock)
mock_session = Mock(
name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None
)
mock_session.expand_role = Mock(name="expand_role", return_value=ROLE)

return mock_session
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def sagemaker_session():
region_name=REGION,
config=None,
local_mode=False,
s3_client=None,
s3_resource=None,
)
sms.boto_region_name = REGION
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def sagemaker_session():
region_name=REGION,
config=None,
local_mode=False,
s3_client=None,
s3_resource=None,
)
sms.boto_region_name = REGION
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
Expand Down
9 changes: 8 additions & 1 deletion tests/unit/test_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,14 @@
@pytest.fixture()
def sagemaker_session():
boto_mock = Mock(name="boto_session", region_name=REGION)
sms = Mock(name="sagemaker_session", boto_session=boto_mock, config=None, local_mode=False)
sms = Mock(
name="sagemaker_session",
boto_session=boto_mock,
config=None,
local_mode=False,
s3_client=None,
s3_resource=None,
)
sms.boto_region_name = REGION
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
sms.sagemaker_client.describe_training_job = Mock(
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_linear_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def sagemaker_session():
region_name=REGION,
config=None,
local_mode=False,
s3_client=None,
s3_resource=None,
)
sms.boto_region_name = REGION
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
Expand Down
57 changes: 56 additions & 1 deletion tests/unit/test_local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

import pytest
import urllib3

import os
from botocore.exceptions import ClientError
from mock import Mock, patch
from tests.unit import DATA_DIR

import sagemaker

Expand All @@ -33,6 +34,10 @@
MODEL_NAME = "test-model"
PRIMARY_CONTAINER = {"ModelDataUrl": "/some/model/path", "Environment": {"env1": 1, "env2": "b"}}

ENDPOINT_URL = "http://127.0.0.1:9000"
BUCKET_NAME = "mybucket"
LS_FILES = {"Contents": [{"Key": "/data/test.csv"}]}


@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
@patch("sagemaker.local.local_session.LocalSession")
Expand Down Expand Up @@ -475,3 +480,53 @@ def test_local_session_is_set_to_local_mode():
boto_session = Mock(region_name="us-west-2")
local_session = sagemaker.local.local_session.LocalSession(boto_session=boto_session)
assert local_session.local_mode


@pytest.fixture()
def sagemaker_session_custom_endpoint():
    """Create a ``LocalSession`` that was given a custom S3 endpoint URL.

    The boto session is a mock whose ``resource`` and ``client`` factories
    return mocks, and ``default_bucket`` is stubbed to return ``BUCKET_NAME``.

    Returns:
        sagemaker.local.local_session.LocalSession: A session constructed with
        ``s3_endpoint_url=ENDPOINT_URL`` on top of the mocked boto session.
    """
    # NOTE: ``name`` has to be passed by keyword. The first positional
    # parameter of ``Mock`` is ``spec``, so ``Mock("boto_session")`` would
    # spec the mock against ``str`` instead of naming it.
    resource_mock = Mock(name="s3_resource")
    client_mock = Mock(name="s3_client")

    boto_session = Mock(name="boto_session")
    boto_session.configure_mock(region_name="us-east-1")
    boto_session.resource = Mock(name="resource", return_value=resource_mock)
    boto_session.client = Mock(name="client", return_value=client_mock)

    local_session = sagemaker.local.local_session.LocalSession(
        boto_session=boto_session, s3_endpoint_url=ENDPOINT_URL
    )

    local_session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
    return local_session


def test_local_session_with_custom_s3_endpoint_url(sagemaker_session_custom_endpoint):
    """Check that the S3 client and resource are created with the custom endpoint URL."""
    session = sagemaker_session_custom_endpoint
    mocked_boto = session.boto_session

    # Both factories must have been called with the endpoint override.
    mocked_boto.client.assert_called_with("s3", endpoint_url=ENDPOINT_URL)
    mocked_boto.resource.assert_called_with("s3", endpoint_url=ENDPOINT_URL)

    # The resulting objects must be stored on the session.
    assert session.s3_client is not None
    assert session.s3_resource is not None


def test_local_session_download_with_custom_s3_endpoint_url(sagemaker_session_custom_endpoint):
    """Check that ``download_data`` downloads through the session's own S3 client."""
    target_dir = os.path.join(DATA_DIR, "download_data_tests")

    # Stub the two client calls the test cares about: the listing and the download.
    s3_client = sagemaker_session_custom_endpoint.s3_client
    s3_client.list_objects_v2 = Mock(name="list_objects_v2", return_value=LS_FILES)
    s3_client.download_file = Mock(name="download_file")

    sagemaker_session_custom_endpoint.download_data(
        target_dir, BUCKET_NAME, key_prefix="/data/test.csv"
    )

    expected_filename = "{}/{}".format(target_dir, "test.csv")
    sagemaker_session_custom_endpoint.s3_client.download_file.assert_called_with(
        Bucket=BUCKET_NAME,
        Key="/data/test.csv",
        Filename=expected_filename,
        ExtraArgs=None,
    )
Loading