Skip to content

Commit 37b3b06

Browse files
committed
Use 'AWS_REGION' again
1 parent 0f6d86d commit 37b3b06

File tree

6 files changed

+44
-33
lines changed

6 files changed

+44
-33
lines changed

src/tf_container/experiment_trainer.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
2+
#
33
# Licensed under the Apache License, Version 2.0 (the "License").
44
# You may not use this file except in compliance with the License.
55
# A copy of the License is located at
6-
#
6+
#
77
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# or in the "license" file accompanying this file. This file is distributed
10-
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11-
# express or implied. See the License for the specific language governing
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13-
import inspect
1413

14+
import boto3
15+
import inspect
16+
import os
1517
import tensorflow as tf
18+
import tf_container.s3_fs as s3_fs
19+
from tf_container.run import logger
1620
from tensorflow.contrib.learn import RunConfig, Experiment
1721
from tensorflow.contrib.learn.python.learn import learn_runner
1822
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
1923
from tensorflow.contrib.training import HParams
2024

21-
import tf_container.s3_fs as s3_fs
22-
from tf_container.run import logger
23-
2425

2526
class Trainer(object):
2627
DEFAULT_TRAINING_CHANNEL = 'training'
@@ -72,7 +73,7 @@ def __init__(self,
7273
self.customer_params = customer_params
7374

7475
if model_path.startswith('s3://'):
75-
s3_fs.configure_s3_fs(model_path, region_name=customer_params.get('sagemaker_region'))
76+
s3_fs.configure_s3_fs(model_path)
7677

7778
def _get_task_type(self, masters):
7879
if self.current_host in masters:

src/tf_container/s3_fs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import container_support as cs
44

55

6-
def configure_s3_fs(checkpoint_path, region_name=None):
6+
def configure_s3_fs(checkpoint_path):
7+
# If env variable is not set, defaults to None, which will use the global endpoint.
8+
region_name = os.environ.get('AWS_REGION')
79
s3 = boto3.client('s3', region_name=region_name)
810

911
# We get the AWS region of the checkpoint bucket, which may be different from

src/tf_container/serve.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,26 +10,27 @@
1010
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13-
import csv
13+
1414
import json
15-
import os
1615
import shutil
1716
import subprocess
18-
import time
19-
2017
import boto3
18+
import container_support as cs
2119
import google.protobuf.json_format as json_format
20+
import os
21+
2222
from grpc import StatusCode
2323
from grpc.framework.interfaces.face.face import AbortionError
24-
from six import StringIO
2524
from tensorflow.core.framework import tensor_pb2
26-
27-
import container_support as cs
25+
from tf_container import proxy_client
26+
from six import StringIO
27+
import csv
2828
from container_support.serving import UnsupportedContentTypeError, UnsupportedAcceptTypeError, \
2929
JSON_CONTENT_TYPE, CSV_CONTENT_TYPE, \
3030
OCTET_STREAM_CONTENT_TYPE, ANY_CONTENT_TYPE
31-
from tf_container import proxy_client
3231
from tf_container.run import logger
32+
import time
33+
3334

3435
TF_SERVING_PORT = 9000
3536
GENERIC_MODEL_NAME = "generic_model"

src/tf_container/trainer.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
2+
#
33
# Licensed under the Apache License, Version 2.0 (the "License").
44
# You may not use this file except in compliance with the License.
55
# A copy of the License is located at
6-
#
6+
#
77
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# or in the "license" file accompanying this file. This file is distributed
10-
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11-
# express or implied. See the License for the specific language governing
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13-
import inspect
1413

14+
import boto3
15+
import inspect
16+
import os
1517
import tensorflow as tf
16-
1718
from tf_container.run import logger
1819
import tf_container.s3_fs as s3_fs
1920

@@ -61,7 +62,7 @@ def __init__(self,
6162
self.customer_params = customer_params
6263

6364
if model_path.startswith('s3://'):
64-
s3_fs.configure_s3_fs(model_path, region_name=customer_params.get('sagemaker_region'))
65+
s3_fs.configure_s3_fs(model_path)
6566

6667
def train(self):
6768
run_config = self._build_run_config()

test/unit/test_experiment_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13+
import os
1314

1415
import pytest
1516
from mock import patch, call, MagicMock, ANY
17+
1618
from test.unit.utils import mock_import_modules
1719

1820
mock_script = {}
@@ -113,12 +115,12 @@ def test_build_tf_config_with_multiple_hosts(trainer):
113115
@patch('os.environ')
114116
def test_configure_s3_file_system(os_env, botocore, boto_client, trainer):
115117
region = 'my-region'
118+
os_env.get.return_value = region
116119

117120
trainer.Trainer(customer_script=mock_script,
118121
current_host=current_host,
119122
hosts=hosts,
120-
model_path='s3://my/s3/path',
121-
customer_params={'sagemaker_region': region})
123+
model_path='s3://my/s3/path')
122124

123125
boto_client.assert_called_once_with('s3', region_name=region)
124126
boto_client('s3', region_name=region).get_bucket_location.assert_called_once_with(Bucket='my')
@@ -129,6 +131,7 @@ def test_configure_s3_file_system(os_env, botocore, boto_client, trainer):
129131
]
130132

131133
os_env.__setitem__.assert_has_calls(calls, any_order=True)
134+
os_env.get.assert_called_with('AWS_REGION')
132135

133136

134137
@patch('boto3.client')

test/unit/test_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
# on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
1111
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
13+
import os
1314

1415
import pytest
1516
from mock import patch, call, MagicMock, ANY
17+
1618
from test.unit.utils import mock_import_modules
1719

1820

@@ -393,12 +395,12 @@ def test_build_tf_config_with_multiple_hosts(trainer):
393395
@patch('os.environ')
394396
def test_configure_s3_file_system(os_env, botocore, boto_client, trainer_module):
395397
region = 'my-region'
398+
os_env.get.return_value = region
396399

397400
trainer_module.Trainer(customer_script=MOCK_SCRIPT,
398401
current_host=CURRENT_HOST,
399402
hosts=HOSTS,
400-
model_path='s3://my/s3/path',
401-
customer_params={'sagemaker_region': region})
403+
model_path='s3://my/s3/path')
402404

403405
boto_client.assert_called_once_with('s3', region_name=region)
404406
boto_client('s3', region_name=region).get_bucket_location.assert_called_once_with(Bucket='my')
@@ -409,6 +411,7 @@ def test_configure_s3_file_system(os_env, botocore, boto_client, trainer_module)
409411
]
410412

411413
os_env.__setitem__.assert_has_calls(calls, any_order=False)
414+
os_env.get.assert_called_with('AWS_REGION')
412415

413416

414417
CUSTOMER_PARAMS = HYPERPARAMETERS.copy()

0 commit comments

Comments
 (0)