Skip to content

Commit df19d37

Browse files
authored
Test stability: increase retries on the CloudWatch Logs client and set a training timeout for test_cifar (#159)
* Increase the number of retries on the CloudWatch Logs client
* Set the training timeout on test_cifar to match the client-side timeout
1 parent f1e0781 commit df19d37

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

src/sagemaker/session.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import json
2424
import six
2525
import yaml
26+
import botocore.config
2627
from botocore.exceptions import ClientError
2728

2829
from sagemaker.user_agent import prepend_user_agent
@@ -549,7 +550,7 @@ def get_caller_identity_arn(self):
549550
role = re.sub(r'^(.+)sts::(\d+):assumed-role/(.+?)/.*$', r'\1iam::\2:role/\3', assumed_role)
550551
return role
551552

552-
def logs_for_job(self, job_name, wait=False, poll=5): # noqa: C901 - suppress complexity warning for this method
553+
def logs_for_job(self, job_name, wait=False, poll=10): # noqa: C901 - suppress complexity warning for this method
553554
"""Display the logs for a given training job, optionally tailing them until the
554555
job is complete. If the output is a tty or a Jupyter cell, it will be color-coded
555556
based on which instance the log entry is from.
@@ -569,7 +570,11 @@ def logs_for_job(self, job_name, wait=False, poll=5): # noqa: C901 - suppress c
569570

570571
stream_names = [] # The list of log streams
571572
positions = {} # The current position in each stream, map of stream name -> position
572-
client = self.boto_session.client('logs')
573+
574+
# Increase retries allowed (from default of 4), as we don't want waiting for a training job
575+
# to be interrupted by a transient exception.
576+
config = botocore.config.Config(retries={'max_attempts': 15})
577+
client = self.boto_session.client('logs', config=config)
573578
log_group = '/aws/sagemaker/TrainingJobs'
574579

575580
job_already_completed = True if status == 'Completed' or status == 'Failed' else False

tests/integ/test_tf_cifar.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_cifar(sagemaker_session, tf_full_version):
3939
estimator = TensorFlow(entry_point='resnet_cifar_10.py', source_dir=script_path, role='SageMakerRole',
4040
framework_version=tf_full_version, training_steps=20, evaluation_steps=5,
4141
train_instance_count=2, train_instance_type='ml.p2.xlarge',
42-
sagemaker_session=sagemaker_session,
42+
sagemaker_session=sagemaker_session, train_max_run=20 * 60,
4343
base_job_name='test-cifar')
4444

4545
inputs = estimator.sagemaker_session.upload_data(path=dataset_path, key_prefix='data/cifar10')

0 commit comments

Comments
 (0)