Skip to content

Commit d8b3012

Browse files
authored
feature: Allow setting S3 endpoint URL for Local Session (#1359)
1 parent 8f1c96b commit d8b3012

32 files changed

+219
-18
lines changed

src/sagemaker/estimator.py

+1
Original file line numberDiff line numberDiff line change
@@ -1713,6 +1713,7 @@ def _stage_user_code_in_s3(self):
17131713
directory=self.source_dir,
17141714
dependencies=self.dependencies,
17151715
kms_key=kms_key,
1716+
s3_resource=self.sagemaker_session.s3_resource,
17161717
)
17171718

17181719
def _model_source_dir(self):

src/sagemaker/fw_utils.py

+17-2
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,14 @@ def validate_source_dir(script, directory):
307307

308308

309309
def tar_and_upload_dir(
310-
session, bucket, s3_key_prefix, script, directory=None, dependencies=None, kms_key=None
310+
session,
311+
bucket,
312+
s3_key_prefix,
313+
script,
314+
directory=None,
315+
dependencies=None,
316+
kms_key=None,
317+
s3_resource=None,
311318
):
312319
"""Package source files and upload a compress tar file to S3. The S3
313320
location will be ``s3://<bucket>/s3_key_prefix/sourcedir.tar.gz``.
@@ -331,6 +338,9 @@ def tar_and_upload_dir(
331338
copied into /opt/ml/lib
332339
kms_key (str): Optional. KMS key ID used to upload objects to the bucket
333340
(default: None).
341+
s3_resource (boto3.resource("s3")): Optional. Pre-instantiated Boto3 Resource
342+
for S3 connections, can be used to customize the configuration,
343+
e.g. set the endpoint URL (default: None).
334344
Returns:
335345
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and
336346
script name.
@@ -354,7 +364,12 @@ def tar_and_upload_dir(
354364
else:
355365
extra_args = None
356366

357-
session.resource("s3").Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
367+
if s3_resource is None:
368+
s3_resource = session.resource("s3", region_name=session.region_name)
369+
else:
370+
print("Using provided s3_resource")
371+
372+
s3_resource.Object(bucket, key).upload_file(tar_file, ExtraArgs=extra_args)
358373
finally:
359374
shutil.rmtree(tmp)
360375

src/sagemaker/local/image.py

+7
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
# Environment variables to be set during training
4747
REGION_ENV_NAME = "AWS_REGION"
4848
TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME"
49+
S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"
4950

5051
logger = logging.getLogger(__name__)
5152

@@ -139,6 +140,11 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
139140
REGION_ENV_NAME: self.sagemaker_session.boto_region_name,
140141
TRAINING_JOB_NAME_ENV_NAME: job_name,
141142
}
143+
if self.sagemaker_session.s3_resource is not None:
144+
training_env_vars[
145+
S3_ENDPOINT_URL_ENV_NAME
146+
] = self.sagemaker_session.s3_resource.meta.client._endpoint.host
147+
142148
compose_data = self._generate_compose_file(
143149
"train", additional_volumes=volumes, additional_env_vars=training_env_vars
144150
)
@@ -206,6 +212,7 @@ def serve(self, model_dir, environment):
206212
"serve", additional_env_vars=environment, additional_volumes=volumes
207213
)
208214
compose_command = self._compose()
215+
209216
self.container = _HostingContainer(compose_command)
210217
self.container.start()
211218

src/sagemaker/local/local_session.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def create_training_job(
9898
)
9999
training_job = _LocalTrainingJob(container)
100100
hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {}
101+
logger.info("Starting training job")
101102
training_job.start(InputDataConfig, OutputDataConfig, hyperparameters, TrainingJobName)
102103

103104
LocalSagemakerClient._training_jobs[TrainingJobName] = training_job
@@ -377,7 +378,9 @@ def invoke_endpoint(
377378
class LocalSession(Session):
378379
"""Placeholder docstring"""
379380

380-
def __init__(self, boto_session=None):
381+
def __init__(self, boto_session=None, s3_endpoint_url=None):
382+
self.s3_endpoint_url = s3_endpoint_url
383+
381384
super(LocalSession, self).__init__(boto_session)
382385

383386
if platform.system() == "Windows":
@@ -407,6 +410,10 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
407410
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
408411
self.local_mode = True
409412

413+
if self.s3_endpoint_url is not None:
414+
self.s3_resource = boto_session.resource("s3", endpoint_url=self.s3_endpoint_url)
415+
self.s3_client = boto_session.client("s3", endpoint_url=self.s3_endpoint_url)
416+
410417
sagemaker_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
411418
if os.path.exists(sagemaker_config_file):
412419
try:

src/sagemaker/multidatamodel.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,13 @@ def __init__(
9292
self.model = model
9393
self.container_mode = MULTI_MODEL_CONTAINER_MODE
9494
self.sagemaker_session = sagemaker_session or Session()
95-
self.s3_client = self.sagemaker_session.boto_session.client("s3")
95+
96+
if self.sagemaker_session.s3_client is None:
97+
self.s3_client = self.sagemaker_session.boto_session.client(
98+
"s3", region_name=self.sagemaker_session.boto_session.region_name
99+
)
100+
else:
101+
self.s3_client = self.sagemaker_session.s3_client
96102

97103
# Set the ``Model`` parameters if the model parameter is not specified
98104
if not self.model:

src/sagemaker/session.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ def __init__(
107107
"""
108108
self._default_bucket = None
109109
self._default_bucket_name_override = default_bucket
110-
111-
# currently is used for local_code in local mode
110+
self.s3_resource = None
111+
self.s3_client = None
112112
self.config = None
113113

114114
self._initialize(
@@ -199,7 +199,10 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None):
199199
key_suffix = name
200200

201201
bucket = bucket or self.default_bucket()
202-
s3 = self.boto_session.resource("s3")
202+
if self.s3_resource is None:
203+
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
204+
else:
205+
s3 = self.s3_resource
203206

204207
for local_path, s3_key in files:
205208
s3.Object(bucket, s3_key).upload_file(local_path, ExtraArgs=extra_args)
@@ -227,7 +230,11 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
227230
str: The S3 URI of the uploaded file.
228231
The URI format is: ``s3://{bucket name}/{key}``.
229232
"""
230-
s3 = self.boto_session.resource("s3")
233+
if self.s3_resource is None:
234+
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
235+
else:
236+
s3 = self.s3_resource
237+
231238
s3_object = s3.Object(bucket_name=bucket, key=key)
232239

233240
if kms_key is not None:
@@ -254,7 +261,10 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
254261
255262
"""
256263
# Initialize the S3 client.
257-
s3 = self.boto_session.client("s3")
264+
if self.s3_client is None:
265+
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
266+
else:
267+
s3 = self.s3_client
258268

259269
# Initialize the variables used to loop through the contents of the S3 bucket.
260270
keys = []
@@ -299,7 +309,10 @@ def read_s3_file(self, bucket, key_prefix):
299309
str: The body of the s3 file as a string.
300310
301311
"""
302-
s3 = self.boto_session.client("s3")
312+
if self.s3_client is None:
313+
s3 = self.boto_session.client("s3", region_name=self.boto_region_name)
314+
else:
315+
s3 = self.s3_client
303316

304317
# Explicitly passing a None kms_key to boto3 throws a validation error.
305318
s3_object = s3.get_object(Bucket=bucket, Key=key_prefix)
@@ -317,7 +330,10 @@ def list_s3_files(self, bucket, key_prefix):
317330
[str]: The list of files at the S3 path.
318331
319332
"""
320-
s3 = self.boto_session.resource("s3")
333+
if self.s3_resource is None:
334+
s3 = self.boto_session.resource("s3", region_name=self.boto_region_name)
335+
else:
336+
s3 = self.s3_resource
321337

322338
s3_bucket = s3.Bucket(name=bucket)
323339
s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
@@ -330,6 +346,7 @@ def default_bucket(self):
330346
str: The name of the default bucket, which is of the form:
331347
``sagemaker-{region}-{AWS account ID}``.
332348
"""
349+
333350
if self._default_bucket:
334351
return self._default_bucket
335352

@@ -364,10 +381,14 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
364381
already being created, no exception is raised.
365382
366383
"""
367-
bucket = self.boto_session.resource("s3", region_name=region).Bucket(name=bucket_name)
384+
if self.s3_resource is None:
385+
s3 = self.boto_session.resource("s3", region_name=region)
386+
else:
387+
s3 = self.s3_resource
388+
389+
bucket = s3.Bucket(name=bucket_name)
368390
if bucket.creation_date is None:
369391
try:
370-
s3 = self.boto_session.resource("s3", region_name=region)
371392
if region == "us-east-1":
372393
# 'us-east-1' cannot be specified because it is the default region:
373394
# https://github.com/boto/boto3/issues/125

tests/unit/test_airflow.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def sagemaker_session():
3636
boto_region_name=REGION,
3737
config=None,
3838
local_mode=False,
39+
s3_resource=None,
40+
s3_client=None,
3941
)
4042
session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
4143
session._default_bucket = BUCKET_NAME

tests/unit/test_chainer.py

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def sagemaker_session():
6262
boto_region_name=REGION,
6363
config=None,
6464
local_mode=False,
65+
s3_resource=None,
66+
s3_client=None,
6567
)
6668

6769
describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}

tests/unit/test_estimator.py

+2
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ def sagemaker_session():
174174
boto_region_name=REGION,
175175
config=None,
176176
local_mode=False,
177+
s3_client=None,
178+
s3_resource=None,
177179
)
178180
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
179181
sms.sagemaker_client.describe_training_job = Mock(

tests/unit/test_fm.py

+2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def sagemaker_session():
5555
region_name=REGION,
5656
config=None,
5757
local_mode=False,
58+
s3_client=False,
59+
s3_resource=False,
5860
)
5961
sms.boto_region_name = REGION
6062
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

tests/unit/test_fw_utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ def cd(path):
158158
@pytest.fixture()
159159
def sagemaker_session():
160160
boto_mock = Mock(name="boto_session", region_name=REGION)
161-
session_mock = Mock(name="sagemaker_session", boto_session=boto_mock)
161+
session_mock = Mock(
162+
name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None
163+
)
162164
session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
163165
session_mock.expand_role = Mock(name="expand_role", return_value=ROLE)
164166
session_mock.sagemaker_client.describe_training_job = Mock(

tests/unit/test_job.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def estimator(sagemaker_session):
7373
@pytest.fixture()
7474
def sagemaker_session():
7575
boto_mock = Mock(name="boto_session")
76-
mock_session = Mock(name="sagemaker_session", boto_session=boto_mock)
76+
mock_session = Mock(
77+
name="sagemaker_session", boto_session=boto_mock, s3_client=None, s3_resource=None
78+
)
7779
mock_session.expand_role = Mock(name="expand_role", return_value=ROLE)
7880

7981
return mock_session

tests/unit/test_kmeans.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def sagemaker_session():
4949
region_name=REGION,
5050
config=None,
5151
local_mode=False,
52+
s3_client=None,
53+
s3_resource=None,
5254
)
5355
sms.boto_region_name = REGION
5456
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

tests/unit/test_knn.py

+2
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def sagemaker_session():
5555
region_name=REGION,
5656
config=None,
5757
local_mode=False,
58+
s3_client=None,
59+
s3_resource=None,
5860
)
5961
sms.boto_region_name = REGION
6062
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

tests/unit/test_lda.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,14 @@
3939
@pytest.fixture()
4040
def sagemaker_session():
4141
boto_mock = Mock(name="boto_session", region_name=REGION)
42-
sms = Mock(name="sagemaker_session", boto_session=boto_mock, config=None, local_mode=False)
42+
sms = Mock(
43+
name="sagemaker_session",
44+
boto_session=boto_mock,
45+
config=None,
46+
local_mode=False,
47+
s3_client=None,
48+
s3_resource=None,
49+
)
4350
sms.boto_region_name = REGION
4451
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
4552
sms.sagemaker_client.describe_training_job = Mock(

tests/unit/test_linear_learner.py

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def sagemaker_session():
5050
region_name=REGION,
5151
config=None,
5252
local_mode=False,
53+
s3_client=None,
54+
s3_resource=None,
5355
)
5456
sms.boto_region_name = REGION
5557
sms.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

tests/unit/test_local_session.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,10 @@
1414

1515
import pytest
1616
import urllib3
17-
17+
import os
1818
from botocore.exceptions import ClientError
1919
from mock import Mock, patch
20+
from tests.unit import DATA_DIR
2021

2122
import sagemaker
2223

@@ -33,6 +34,10 @@
3334
MODEL_NAME = "test-model"
3435
PRIMARY_CONTAINER = {"ModelDataUrl": "/some/model/path", "Environment": {"env1": 1, "env2": "b"}}
3536

37+
ENDPOINT_URL = "http://127.0.0.1:9000"
38+
BUCKET_NAME = "mybucket"
39+
LS_FILES = {"Contents": [{"Key": "/data/test.csv"}]}
40+
3641

3742
@patch("sagemaker.local.image._SageMakerContainer.train", return_value="/some/path/to/model")
3843
@patch("sagemaker.local.local_session.LocalSession")
@@ -475,3 +480,53 @@ def test_local_session_is_set_to_local_mode():
475480
boto_session = Mock(region_name="us-west-2")
476481
local_session = sagemaker.local.local_session.LocalSession(boto_session=boto_session)
477482
assert local_session.local_mode
483+
484+
485+
@pytest.fixture()
486+
def sagemaker_session_custom_endpoint():
487+
488+
boto_session = Mock("boto_session")
489+
resource_mock = Mock("resource")
490+
client_mock = Mock("client")
491+
boto_attrs = {"region_name": "us-east-1"}
492+
boto_session.configure_mock(**boto_attrs)
493+
boto_session.resource = Mock(name="resource", return_value=resource_mock)
494+
boto_session.client = Mock(name="client", return_value=client_mock)
495+
496+
local_session = sagemaker.local.local_session.LocalSession(
497+
boto_session=boto_session, s3_endpoint_url=ENDPOINT_URL
498+
)
499+
500+
local_session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
501+
return local_session
502+
503+
504+
def test_local_session_with_custom_s3_endpoint_url(sagemaker_session_custom_endpoint):
505+
506+
boto_session = sagemaker_session_custom_endpoint.boto_session
507+
508+
boto_session.client.assert_called_with("s3", endpoint_url=ENDPOINT_URL)
509+
boto_session.resource.assert_called_with("s3", endpoint_url=ENDPOINT_URL)
510+
511+
assert sagemaker_session_custom_endpoint.s3_client is not None
512+
assert sagemaker_session_custom_endpoint.s3_resource is not None
513+
514+
515+
def test_local_session_download_with_custom_s3_endpoint_url(sagemaker_session_custom_endpoint):
516+
517+
DOWNLOAD_DATA_TESTS_FILES_DIR = os.path.join(DATA_DIR, "download_data_tests")
518+
sagemaker_session_custom_endpoint.s3_client.list_objects_v2 = Mock(
519+
name="list_objects_v2", return_value=LS_FILES
520+
)
521+
sagemaker_session_custom_endpoint.s3_client.download_file = Mock(name="download_file")
522+
523+
sagemaker_session_custom_endpoint.download_data(
524+
DOWNLOAD_DATA_TESTS_FILES_DIR, BUCKET_NAME, key_prefix="/data/test.csv"
525+
)
526+
527+
sagemaker_session_custom_endpoint.s3_client.download_file.assert_called_with(
528+
Bucket=BUCKET_NAME,
529+
Key="/data/test.csv",
530+
Filename="{}/{}".format(DOWNLOAD_DATA_TESTS_FILES_DIR, "test.csv"),
531+
ExtraArgs=None,
532+
)

0 commit comments

Comments
 (0)