Skip to content

Commit 5003333

Browse files
committed
Run integ tests for pytorch against local mode and real cpu instance.
1 parent 922a7a5 commit 5003333

File tree

2 files changed

+26
-8
lines changed

2 files changed

+26
-8
lines changed

tests/integ/conftest.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
import pytest
14+
15+
16+
@pytest.fixture(scope='session', params=['local', 'ml.c4.xlarge'])
17+
def instance_type(request):
18+
return request.param

tests/integ/test_pytorch_train.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,22 @@
2222

2323

2424
@pytest.fixture(scope='module', name='pytorch_training_job')
25-
def fixture_training_job(sagemaker_session, pytorch_full_version):
25+
def fixture_training_job(sagemaker_session, pytorch_full_version, instance_type):
2626
with timeout(minutes=15):
2727
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
28-
train_instance_count=1, train_instance_type='ml.c4.xlarge',
28+
train_instance_count=1, train_instance_type=instance_type,
2929
sagemaker_session=sagemaker_session)
3030

3131
pytorch.fit({'training': _upload_training_data(pytorch)})
3232
return pytorch.latest_training_job.name
3333

3434

35-
def test_sync_fit(sagemaker_session, pytorch_full_version):
35+
def test_sync_fit(sagemaker_session, pytorch_full_version, instance_type):
3636
training_job_name = ""
3737

3838
with timeout(minutes=15):
3939
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
40-
train_instance_count=1, train_instance_type='ml.c4.xlarge',
40+
train_instance_count=1, train_instance_type=instance_type,
4141
sagemaker_session=sagemaker_session)
4242

4343
pytorch.fit({'training': _upload_training_data(pytorch)})
@@ -47,12 +47,12 @@ def test_sync_fit(sagemaker_session, pytorch_full_version):
4747
PyTorch.attach(training_job_name, sagemaker_session=sagemaker_session)
4848

4949

50-
def test_async_fit(sagemaker_session, pytorch_full_version):
50+
def test_async_fit(sagemaker_session, pytorch_full_version, instance_type):
5151
training_job_name = ""
5252

5353
with timeout(minutes=10):
5454
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
55-
train_instance_count=1, train_instance_type='ml.c4.xlarge',
55+
train_instance_count=1, train_instance_type=instance_type,
5656
sagemaker_session=sagemaker_session)
5757

5858
pytorch.fit({'training': _upload_training_data(pytorch)}, wait=False)
@@ -66,12 +66,12 @@ def test_async_fit(sagemaker_session, pytorch_full_version):
6666
PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
6767

6868

69-
def test_failed_training_job(sagemaker_session, pytorch_full_version):
69+
def test_failed_training_job(sagemaker_session, pytorch_full_version, instance_type):
7070
script_path = os.path.join(MNIST_DIR, 'failure_script.py')
7171

7272
with timeout(minutes=15):
7373
pytorch = PyTorch(entry_point=script_path, role='SageMakerRole', framework_version=pytorch_full_version,
74-
train_instance_count=1, train_instance_type='ml.c4.xlarge',
74+
train_instance_count=1, train_instance_type=instance_type,
7575
sagemaker_session=sagemaker_session)
7676

7777
with pytest.raises(ValueError) as e:

0 commit comments

Comments
 (0)