From 0d74c5b9bd4017f3d4b90a5c9b16a5a580105f14 Mon Sep 17 00:00:00 2001 From: Dong Date: Fri, 20 Apr 2018 17:49:17 -0700 Subject: [PATCH 1/2] Increase num retries on cloudwatch logs client --- src/sagemaker/session.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 33aa81188f..f63d9af9ed 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -23,6 +23,7 @@ import json import six import yaml +import botocore.config from botocore.exceptions import ClientError from sagemaker.user_agent import prepend_user_agent @@ -549,7 +550,7 @@ def get_caller_identity_arn(self): role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/\3', assumed_role) return role - def logs_for_job(self, job_name, wait=False, poll=5): # noqa: C901 - suppress complexity warning for this method + def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress complexity warning for this method """Display the logs for a given training job, optionally tailing them until the job is complete. If the output is a tty or a Jupyter cell, it will be color-coded based on which instance the log entry is from. @@ -569,7 +570,11 @@ def logs_for_job(self, job_name, wait=False, poll=5): # noqa: C901 - suppress c stream_names = [] # The list of log streams positions = {} # The current position in each stream, map of stream name -> position - client = self.boto_session.client('logs') + + # Increase retries allowed (from default of 4), as we don't want waiting for a training job + # to be interrupted by a transient exception. + config = botocore.config.Config(retries={'max_attempts': 15}) + client = self.boto_session.client('logs', config=config) log_group = '/aws/sagemaker/TrainingJobs' job_already_completed = True if status == 'Completed' or status == 'Failed' else False From a6675e79a47c13128048fb1f069dc4d80f2eb725 Mon Sep 17 00:00:00 2001 From: Winston Dong Date: Fri, 20 Apr 2018 17:57:49 -0700 Subject: [PATCH 2/2] Set training timeout on test_cifar to match client-side timeout --- tests/integ/test_tf_cifar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_tf_cifar.py b/tests/integ/test_tf_cifar.py index b639b5efb7..cbc479abeb 100644 --- a/tests/integ/test_tf_cifar.py +++ b/tests/integ/test_tf_cifar.py @@ -39,7 +39,7 @@ def test_cifar(sagemaker_session, tf_full_version): estimator = TensorFlow(entry_point='resnet_cifar_10.py', source_dir=script_path, role='SageMakerRole', framework_version=tf_full_version, training_steps=20, evaluation_steps=5, train_instance_count=2, train_instance_type='ml.p2.xlarge', - sagemaker_session=sagemaker_session, + sagemaker_session=sagemaker_session, train_max_run=20 * 60, base_job_name='test-cifar') inputs = estimator.sagemaker_session.upload_data(path=dataset_path, key_prefix='data/cifar10')