Skip to content

Commit 2d4cb33

Browse files
committed
Do not try to reattach to the training job when running locally since it's not yet supported.
1 parent 5003333 commit 2d4cb33

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

tests/integ/test_pytorch_train.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,11 @@ def test_sync_fit(sagemaker_session, pytorch_full_version, instance_type):
4141
sagemaker_session=sagemaker_session)
4242

4343
pytorch.fit({'training': _upload_training_data(pytorch)})
44-
training_job_name = pytorch.latest_training_job.name
44+
training_job_name = pytorch.latest_training_job.name
4545

46-
with timeout(minutes=20):
47-
PyTorch.attach(training_job_name, sagemaker_session=sagemaker_session)
46+
if not _is_local_mode(instance_type):
47+
with timeout(minutes=20):
48+
PyTorch.attach(training_job_name, sagemaker_session=sagemaker_session)
4849

4950

5051
def test_async_fit(sagemaker_session, pytorch_full_version, instance_type):
@@ -61,9 +62,10 @@ def test_async_fit(sagemaker_session, pytorch_full_version, instance_type):
6162
print("Waiting to re-attach to the training job: %s" % training_job_name)
6263
time.sleep(20)
6364

64-
with timeout(minutes=35):
65-
print("Re-attaching now to: %s" % training_job_name)
66-
PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
65+
if not _is_local_mode(instance_type):
66+
with timeout(minutes=35):
67+
print("Re-attaching now to: %s" % training_job_name)
68+
PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
6769

6870

6971
def test_failed_training_job(sagemaker_session, pytorch_full_version, instance_type):
@@ -81,4 +83,8 @@ def test_failed_training_job(sagemaker_session, pytorch_full_version, instance_t
8183

8284
def _upload_training_data(pytorch):
8385
return pytorch.sagemaker_session.upload_data(path=os.path.join(MNIST_DIR, 'training'),
84-
key_prefix='integ-test-data/pytorch_mnist/training')
86+
key_prefix='integ-test-data/pytorch_mnist/training')
87+
88+
89+
def _is_local_mode(instance_type):
90+
return instance_type == 'local'

0 commit comments

Comments
 (0)