Skip to content

Commit 39fb741

Browse files
authored
Allow training without S3 (#39)
This makes the container skip downloading the training script when its location is not an S3 URI; in that case we assume the script was mounted into the container directly. It also fixes model collection when the checkpoints are written to the local filesystem.
1 parent c8e0494 commit 39fb741

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

src/tf_container/serve.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,25 @@ def export_saved_model(checkpoint_dir, model_path, s3=boto3.client('s3')):
7575
s3.download_file(bucket_name, key, target)
7676
else:
7777
if os.path.exists(checkpoint_dir):
78-
shutil.copy2(checkpoint_dir, model_path)
78+
_recursive_copy(checkpoint_dir, model_path)
7979
else:
8080
logger.error("Failed to copy saved model. File does not exist in {}".format(checkpoint_dir))
8181

8282

83+
def _recursive_copy(src, dst):
    """Copy the directory tree rooted at ``src`` into ``dst``.

    Every directory found under ``src`` (empty ones included) is recreated
    under ``dst`` and every file is copied with ``shutil.copy``. Destination
    directories that already exist are reused; missing ones, including
    ``dst`` itself and any of its parents, are created.

    Args:
        src: path of the directory to copy from. If ``os.walk`` yields
            nothing for it (e.g. it is not a directory), nothing is copied.
        dst: path of the directory to copy into.
    """
    for root, dirnames, filenames in os.walk(src):
        relative_root = os.path.relpath(root, src)
        # Avoid producing paths such as "dst/." for the top-level directory.
        if relative_root == os.curdir:
            target_dir = dst
        else:
            target_dir = os.path.join(dst, relative_root)

        # The destination must exist before files are copied into it; without
        # this, top-level files fail to copy when ``dst`` is missing.
        if not os.path.isdir(target_dir):
            os.makedirs(target_dir)

        for filename in filenames:
            shutil.copy(os.path.join(root, filename), os.path.join(target_dir, filename))

        # Create sub-directories eagerly so that entries os.walk lists but
        # does not descend into still appear in the destination.
        for dirname in dirnames:
            new_dir = os.path.join(target_dir, dirname)
            if not os.path.exists(new_dir):
                os.mkdir(new_dir)
95+
96+
8397
def transformer(user_module):
8498
grpc_proxy_client = proxy_client.GRPCProxyClient(TF_SERVING_PORT)
8599
_wait_model_to_load(grpc_proxy_client, TF_SERVING_MAXIMUM_LOAD_MODEL_TIME_IN_SECONDS)

src/tf_container/train_entry_point.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def train():
136136
# saving checkpoints of larger sizes.
137137
os.environ['S3_REQUEST_TIMEOUT_MSEC'] = str(env.hyperparameters.get('s3_checkpoint_save_timeout', 60000))
138138

139-
env.download_user_module()
139+
if env.user_script_archive.lower().startswith('s3://'):
140+
env.download_user_module()
140141
env.pip_install_requirements()
141142

142143
customer_script = env.import_user_module()

test/unit/test_serve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def test_export_saved_model_from_filesystem(mock_exists, mock_makedirs, serve):
8080
checkpoint_dir = 'a/dir'
8181
model_path = 'possible/another/dir'
8282

83-
with patch('shutil.copy2') as mock_copy:
83+
with patch('tf_container.serve._recursive_copy') as mock_copy:
8484
serve.export_saved_model(checkpoint_dir, model_path)
8585
mock_copy.assert_called_once_with(checkpoint_dir, model_path)
8686

0 commit comments

Comments
 (0)