Skip to content

feature: Allow setting S3 endpoint URL for Local Session #1359

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Mar 26, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,7 @@ def __init__(
checkpoint_s3_uri=None,
checkpoint_local_path=None,
enable_sagemaker_metrics=None,
s3_client=None,
**kwargs
):
"""Base class initializer. Subclasses which override ``__init__`` should
Expand Down Expand Up @@ -1566,6 +1567,8 @@ def __init__(
Series. For more information see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
(default: ``None``).
s3_client (boto3.resource('s3')): Optional. Pre-instantiated Boto3 S3 service
resource (not a low-level client; it is used via ``.Object(bucket, key)``) for
S3 connections, can be used to set e.g. the endpoint URL. (default: None).
**kwargs: Additional kwargs passed to the ``EstimatorBase``
constructor.

Expand All @@ -1581,6 +1584,7 @@ def __init__(
entry_point
)
)
self.sagemaker_session.s3_client = s3_client
self.entry_point = entry_point
self.git_config = git_config
self.source_dir = source_dir
Expand Down Expand Up @@ -1708,6 +1712,7 @@ def _stage_user_code_in_s3(self):
directory=self.source_dir,
dependencies=self.dependencies,
kms_key=kms_key,
s3_client=self.sagemaker_session.s3_client,
)

def _model_source_dir(self):
Expand Down
18 changes: 16 additions & 2 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,14 @@ def validate_source_dir(script, directory):


def tar_and_upload_dir(
session, bucket, s3_key_prefix, script, directory=None, dependencies=None, kms_key=None
session,
bucket,
s3_key_prefix,
script,
directory=None,
dependencies=None,
kms_key=None,
s3_client=None,
):
"""Package source files and upload a compress tar file to S3. The S3
location will be ``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
Expand All @@ -331,6 +338,8 @@ def tar_and_upload_dir(
copied into /opt/ml/lib
kms_key (str): Optional. KMS key ID used to upload objects to the bucket
(default: None).
s3_client (boto3.resource('s3')): Optional. Pre-instantiated Boto3 S3 service resource
(not a low-level client; it is used via ``.Object(bucket, key)``), can be used to
set e.g. the endpoint URL (default: None).
Returns:
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
script name.
Expand All @@ -354,7 +363,12 @@ def tar_and_upload_dir(
else:
extra_args = None

session.resource("s3").Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
if s3_client is None:
s3_client = session.resource("s3")
else:
print("Using provided s3_client")

s3_client.Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
finally:
shutil.rmtree(tmp)

Expand Down
10 changes: 10 additions & 0 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
# Environment variables to be set during training
REGION_ENV_NAME = "AWS_REGION"
TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME"
S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -139,11 +140,18 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
REGION_ENV_NAME: self.sagemaker_session.boto_region_name,
TRAINING_JOB_NAME_ENV_NAME: job_name,
}
if self.sagemaker_session.s3_client is not None:
training_env_vars[
S3_ENDPOINT_URL_ENV_NAME
] = self.sagemaker_session.s3_client.meta.client._endpoint.host

compose_data = self._generate_compose_file(
"train", additional_volumes=volumes, additional_env_vars=training_env_vars
)
compose_command = self._compose()

logger.info("Trying to launch image: %s", str(self.image))

if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image):
_pull_image(self.image)

Expand Down Expand Up @@ -206,6 +214,7 @@ def serve(self, model_dir, environment):
"serve", additional_env_vars=environment, additional_volumes=volumes
)
compose_command = self._compose()
logger.info("Compose command: %s", compose_command)
self.container = _HostingContainer(compose_command)
self.container.start()

Expand Down Expand Up @@ -543,6 +552,7 @@ def _create_tmp_folder(self):
)
if root_dir:
root_dir = os.path.abspath(root_dir)
logger.info("Using %s for container temp files.", root_dir)

working_dir = tempfile.mkdtemp(dir=root_dir)

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def create_training_job(
)
training_job = _LocalTrainingJob(container)
hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {}
logger.info("Starting training job")
training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName)

LocalSagemakerClient._training_jobs[TrainingJobName] = training_job
Expand Down Expand Up @@ -377,8 +378,8 @@ def invoke_endpoint(
class LocalSession(Session):
"""Placeholder docstring"""

def __init__(self, boto_session=None):
super(LocalSession, self).__init__(boto_session)
def __init__(self, boto_session=None, s3_client=None):
super(LocalSession, self).__init__(boto_session, s3_client=s3_client)

if platform.system() == "Windows":
logger.warning("Windows Support for Local Mode is Experimental")
Expand Down
25 changes: 21 additions & 4 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import botocore.config
from botocore.exceptions import ClientError
import six
import yaml

import sagemaker.logs
from sagemaker import vpc_utils
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
sagemaker_client=None,
sagemaker_runtime_client=None,
default_bucket=None,
s3_client=None,
):
"""Initialize a SageMaker ``Session``.

Expand All @@ -103,13 +105,20 @@ def __init__(
If not provided, a default bucket will be created based on the following format:
"sagemaker-{region}-{aws-account-id}".
Example: "sagemaker-my-custom-bucket".
s3_client (boto3.resource('s3')): Optional. Pre-instantiated Boto3 S3 service
resource (not a low-level client; it is used via ``.Object(bucket, key)``),
can be used to set e.g. the endpoint URL (default: None).

"""
self._default_bucket = None
self._default_bucket_name_override = default_bucket
self.s3_client = s3_client

# currently is used for local_code in local mode
self.config = None
sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
print("Looking for config file: {}".format(sagemaker_config_file))
if os.path.exists(sagemaker_config_file):
self.config = yaml.load(open(sagemaker_config_file, "r"))
else:
self.config = None

self._initialize(
boto_session=boto_session,
Expand Down Expand Up @@ -199,7 +208,10 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None):
key_suffix = name

bucket = bucket or self.default_bucket()
s3 = self.boto_session.resource("s3")
if self.s3_client is None:
s3 = self.boto_session.resource("s3")
else:
s3 = self.s3_client

for local_path, s3_key in files:
s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args)
Expand Down Expand Up @@ -330,6 +342,7 @@ def default_bucket(self):
str: The name of the default bucket, which is of the form:
``sagemaker-{region}-{AWS account ID}``.
"""

if self._default_bucket:
return self._default_bucket

Expand Down Expand Up @@ -367,7 +380,11 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
bucket = self.boto_session.resource("s3", region_name=region).Bucket(name=bucket_name)
if bucket.creation_date is None:
try:
s3 = self.boto_session.resource("s3", region_name=region)
if self.s3_client is None:
s3 = self.boto_session.resource("s3", region_name=region)
else:
s3 = self.s3_client

if region == "us-east-1":
# 'us-east-1' cannot be specified because it is the default region:
# https://github.com/boto/boto3/issues/125
Expand Down