diff --git a/src/tf_container/serve.py b/src/tf_container/serve.py index 96516d46..69f0986e 100644 --- a/src/tf_container/serve.py +++ b/src/tf_container/serve.py @@ -1,14 +1,14 @@ # Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. # A copy of the License is located at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing # permissions and limitations under the License. import json @@ -50,8 +50,8 @@ def export_saved_model(checkpoint_dir, model_path, s3=boto3.client('s3')): except KeyError as e: logger.error("Failed to download saved model. File does not exist in {}".format(checkpoint_dir)) raise e - - saved_model_path = saved_model_path_array[0] + # Select most recent saved_model.pb + saved_model_path = saved_model_path_array[-1] variables_path = [x['Key'] for x in contents if 'variables/variables' in x['Key']] variable_names_to_paths = {v.split('/').pop(): v for v in variables_path}