Skip to content

Create regional S3 client for model exporting #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions src/tf_container/experiment_trainer.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import boto3
import inspect
import os

import tensorflow as tf
import tf_container.s3_fs as s3_fs
from tf_container.run import logger
from tensorflow.contrib.learn import RunConfig, Experiment
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.training import HParams

import tf_container.s3_fs as s3_fs
from tf_container.run import logger


class Trainer(object):
DEFAULT_TRAINING_CHANNEL = 'training'
Expand Down Expand Up @@ -73,7 +72,7 @@ def __init__(self,
self.customer_params = customer_params

if model_path.startswith('s3://'):
s3_fs.configure_s3_fs(model_path)
s3_fs.configure_s3_fs(model_path, region_name=customer_params.get('sagemaker_region'))

def _get_task_type(self, masters):
if self.current_host in masters:
Expand Down
4 changes: 1 addition & 3 deletions src/tf_container/s3_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import container_support as cs


def configure_s3_fs(checkpoint_path):
# If env variable is not set, defaults to None, which will use the global endpoint.
region_name = os.environ.get('AWS_REGION')
def configure_s3_fs(checkpoint_path, region_name=None):
s3 = boto3.client('s3', region_name=region_name)

# We get the AWS region of the checkpoint bucket, which may be different from
Expand Down
17 changes: 8 additions & 9 deletions src/tf_container/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,26 @@
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import csv
import json
import os
import shutil
import subprocess
import time

import boto3
import container_support as cs
import google.protobuf.json_format as json_format
import os

from grpc import StatusCode
from grpc.framework.interfaces.face.face import AbortionError
from tensorflow.core.framework import tensor_pb2
from tf_container import proxy_client
from six import StringIO
import csv
from tensorflow.core.framework import tensor_pb2

import container_support as cs
from container_support.serving import UnsupportedContentTypeError, UnsupportedAcceptTypeError, \
JSON_CONTENT_TYPE, CSV_CONTENT_TYPE, \
OCTET_STREAM_CONTENT_TYPE, ANY_CONTENT_TYPE
from tf_container import proxy_client
from tf_container.run import logger
import time


TF_SERVING_PORT = 9000
GENERIC_MODEL_NAME = "generic_model"
Expand Down
5 changes: 2 additions & 3 deletions src/tf_container/train_entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import argparse
import json
import os
import subprocess
import time
from threading import Thread

import boto3
import tensorflow as tf

import container_support as cs
import tf_container.run
import tf_container.s3_fs as s3_fs
import tf_container.serve as serve

_logger = tf_container.run.get_logger()
Expand Down Expand Up @@ -165,7 +164,7 @@ def train():

# only the master should export the model at the end of the execution
if checkpoint_dir != env.model_dir and train_wrapper.task_type == 'master' and train_wrapper.saves_training():
serve.export_saved_model(checkpoint_dir, env.model_dir)
serve.export_saved_model(checkpoint_dir, env.model_dir, s3=boto3.client('s3', region_name=env.sagemaker_region))

if train_wrapper.task_type != 'master':
_wait_until_master_is_down(_get_master(tf_config))
19 changes: 9 additions & 10 deletions src/tf_container/trainer.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import boto3
import inspect
import os

import tensorflow as tf

from tf_container.run import logger
import tf_container.s3_fs as s3_fs

Expand Down Expand Up @@ -62,7 +61,7 @@ def __init__(self,
self.customer_params = customer_params

if model_path.startswith('s3://'):
s3_fs.configure_s3_fs(model_path)
s3_fs.configure_s3_fs(model_path, region_name=customer_params.get('sagemaker_region'))

def train(self):
run_config = self._build_run_config()
Expand Down
17 changes: 9 additions & 8 deletions test/unit/test_experiment_trainer.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytest
Expand Down Expand Up @@ -112,12 +112,13 @@ def test_build_tf_config_with_multiple_hosts(trainer):
@patch('botocore.session.get_session')
@patch('os.environ')
def test_configure_s3_file_system(os_env, botocore, boto_client, trainer):
region = os_env.get('AWS_REGION')
region = 'my-region'

trainer.Trainer(customer_script=mock_script,
current_host=current_host,
hosts=hosts,
model_path='s3://my/s3/path')
model_path='s3://my/s3/path',
customer_params={'sagemaker_region': region})

boto_client.assert_called_once_with('s3', region_name=region)
boto_client('s3', region_name=region).get_bucket_location.assert_called_once_with(Bucket='my')
Expand Down
7 changes: 4 additions & 3 deletions test/unit/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def test_user_model_fn(modules, trainer):
estimator = trainer._build_estimator(fake_run_config)

estimator_mock = modules.estimator.Estimator
# Verify that _model_fn passed to Estimator correctly passes args through to user script model_fn
# Verify that _model_fn passed to Estimator correctly passes args through to user script model_fn
estimator_mock.assert_called_with(model_fn=ANY, params=expected_hps, config=fake_run_config)
_, kwargs, = estimator_mock.call_args
kwargs['model_fn'](1, 2, 3, 4)
Expand Down Expand Up @@ -392,12 +392,13 @@ def test_build_tf_config_with_multiple_hosts(trainer):
@patch('botocore.session.get_session')
@patch('os.environ')
def test_configure_s3_file_system(os_env, botocore, boto_client, trainer_module):
region = os_env.get('AWS_REGION')
region = 'my-region'

trainer_module.Trainer(customer_script=MOCK_SCRIPT,
current_host=CURRENT_HOST,
hosts=HOSTS,
model_path='s3://my/s3/path')
model_path='s3://my/s3/path',
customer_params={'sagemaker_region': region})

boto_client.assert_called_once_with('s3', region_name=region)
boto_client('s3', region_name=region).get_bucket_location.assert_called_once_with(Bucket='my')
Expand Down