Skip to content

Commit db57952

Browse files
committed
Fix regional S3 client creation
1 parent 3dd0313 commit db57952

File tree

7 files changed

+44
-48
lines changed

7 files changed

+44
-48
lines changed

src/tf_container/experiment_trainer.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
11
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
2+
#
33
# Licensed under the Apache License, Version 2.0 (the "License").
44
# You may not use this file except in compliance with the License.
55
# A copy of the License is located at
6-
#
6+
#
77
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# or in the "license" file accompanying this file. This file is distributed
10-
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11-
# express or implied. See the License for the specific language governing
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13-
14-
import boto3
1513
import inspect
16-
import os
14+
1715
import tensorflow as tf
18-
import tf_container.s3_fs as s3_fs
19-
from tf_container.run import logger
2016
from tensorflow.contrib.learn import RunConfig, Experiment
2117
from tensorflow.contrib.learn.python.learn import learn_runner
2218
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
2319
from tensorflow.contrib.training import HParams
2420

21+
import tf_container.s3_fs as s3_fs
22+
from tf_container.run import logger
23+
2524

2625
class Trainer(object):
2726
DEFAULT_TRAINING_CHANNEL = 'training'
@@ -73,7 +72,7 @@ def __init__(self,
7372
self.customer_params = customer_params
7473

7574
if model_path.startswith('s3://'):
76-
s3_fs.configure_s3_fs(model_path)
75+
s3_fs.configure_s3_fs(model_path, region_name=customer_params.get('sagemaker_region'))
7776

7877
def _get_task_type(self, masters):
7978
if self.current_host in masters:

src/tf_container/s3_fs.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
import container_support as cs
44

55

6-
def configure_s3_fs(checkpoint_path):
7-
# If env variable is not set, defaults to None, which will use the global endpoint.
8-
region_name = os.environ.get('AWS_REGION')
6+
def configure_s3_fs(checkpoint_path, region_name=None):
97
s3 = boto3.client('s3', region_name=region_name)
108

119
# We get the AWS region of the checkpoint bucket, which may be different from

src/tf_container/serve.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,26 @@
1010
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13-
13+
import csv
1414
import json
15+
import os
1516
import shutil
1617
import subprocess
18+
import time
19+
1720
import boto3
18-
import container_support as cs
1921
import google.protobuf.json_format as json_format
20-
import os
21-
2222
from grpc import StatusCode
2323
from grpc.framework.interfaces.face.face import AbortionError
24-
from tensorflow.core.framework import tensor_pb2
25-
from tf_container import proxy_client
2624
from six import StringIO
27-
import csv
25+
from tensorflow.core.framework import tensor_pb2
26+
27+
import container_support as cs
2828
from container_support.serving import UnsupportedContentTypeError, UnsupportedAcceptTypeError, \
2929
JSON_CONTENT_TYPE, CSV_CONTENT_TYPE, \
3030
OCTET_STREAM_CONTENT_TYPE, ANY_CONTENT_TYPE
31+
from tf_container import proxy_client
3132
from tf_container.run import logger
32-
import time
33-
3433

3534
TF_SERVING_PORT = 9000
3635
GENERIC_MODEL_NAME = "generic_model"

src/tf_container/train_entry_point.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,18 @@
1010
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13-
1413
import argparse
1514
import json
1615
import os
1716
import subprocess
1817
import time
1918
from threading import Thread
2019

20+
import boto3
2121
import tensorflow as tf
2222

2323
import container_support as cs
2424
import tf_container.run
25-
import tf_container.s3_fs as s3_fs
2625
import tf_container.serve as serve
2726

2827
_logger = tf_container.run.get_logger()
@@ -165,7 +164,7 @@ def train():
165164

166165
# only the master should export the model at the end of the execution
167166
if checkpoint_dir != env.model_dir and train_wrapper.task_type == 'master' and train_wrapper.saves_training():
168-
serve.export_saved_model(checkpoint_dir, env.model_dir)
167+
serve.export_saved_model(checkpoint_dir, env.model_dir, s3=boto3.client('s3', region_name=env.sagemaker_region))
169168

170169
if train_wrapper.task_type != 'master':
171170
_wait_until_master_is_down(_get_master(tf_config))

src/tf_container/trainer.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
2+
#
33
# Licensed under the Apache License, Version 2.0 (the "License").
44
# You may not use this file except in compliance with the License.
55
# A copy of the License is located at
6-
#
6+
#
77
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# or in the "license" file accompanying this file. This file is distributed
10-
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11-
# express or implied. See the License for the specific language governing
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13-
14-
import boto3
1513
import inspect
16-
import os
14+
1715
import tensorflow as tf
16+
1817
from tf_container.run import logger
1918
import tf_container.s3_fs as s3_fs
2019

@@ -62,7 +61,7 @@ def __init__(self,
6261
self.customer_params = customer_params
6362

6463
if model_path.startswith('s3://'):
65-
s3_fs.configure_s3_fs(model_path)
64+
s3_fs.configure_s3_fs(model_path, region_name=customer_params.get('sagemaker_region'))
6665

6766
def train(self):
6867
run_config = self._build_run_config()

test/unit/test_experiment_trainer.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
2+
#
33
# Licensed under the Apache License, Version 2.0 (the "License").
44
# You may not use this file except in compliance with the License.
55
# A copy of the License is located at
6-
#
6+
#
77
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# or in the "license" file accompanying this file. This file is distributed
10-
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11-
# express or implied. See the License for the specific language governing
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
1313

1414
import pytest
@@ -112,12 +112,13 @@ def test_build_tf_config_with_multiple_hosts(trainer):
112112
@patch('botocore.session.get_session')
113113
@patch('os.environ')
114114
def test_configure_s3_file_system(os_env, botocore, boto_client, trainer):
115-
region = os_env.get('AWS_REGION')
115+
region = 'my-region'
116116

117117
trainer.Trainer(customer_script=mock_script,
118118
current_host=current_host,
119119
hosts=hosts,
120-
model_path='s3://my/s3/path')
120+
model_path='s3://my/s3/path',
121+
customer_params={'sagemaker_region': region})
121122

122123
boto_client.assert_called_once_with('s3', region_name=region)
123124
boto_client('s3', region_name=region).get_bucket_location.assert_called_once_with(Bucket='my')

test/unit/test_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def test_user_model_fn(modules, trainer):
186186
estimator = trainer._build_estimator(fake_run_config)
187187

188188
estimator_mock = modules.estimator.Estimator
189-
# Verify that _model_fn passed to Estimator correctly passes args through to user script model_fn
189+
# Verify that _model_fn passed to Estimator correctly passes args through to user script model_fn
190190
estimator_mock.assert_called_with(model_fn=ANY, params=expected_hps, config=fake_run_config)
191191
_, kwargs, = estimator_mock.call_args
192192
kwargs['model_fn'](1, 2, 3, 4)
@@ -392,12 +392,13 @@ def test_build_tf_config_with_multiple_hosts(trainer):
392392
@patch('botocore.session.get_session')
393393
@patch('os.environ')
394394
def test_configure_s3_file_system(os_env, botocore, boto_client, trainer_module):
395-
region = os_env.get('AWS_REGION')
395+
region = 'my-region'
396396

397397
trainer_module.Trainer(customer_script=MOCK_SCRIPT,
398398
current_host=CURRENT_HOST,
399399
hosts=HOSTS,
400-
model_path='s3://my/s3/path')
400+
model_path='s3://my/s3/path',
401+
customer_params={'sagemaker_region': region})
401402

402403
boto_client.assert_called_once_with('s3', region_name=region)
403404
boto_client('s3', region_name=region).get_bucket_location.assert_called_once_with(Bucket='my')

0 commit comments

Comments
 (0)