Skip to content

Commit e2334a9

Browse files
authored
Merge pull request aws#4 from aws/pytorch-prediction
Add integ tests for pytorch prediction.
2 parents (58b9743 + 7646115), commit e2334a9

File tree

2 files changed

+64
-27
lines changed

2 files changed

+64
-27
lines changed

tests/data/pytorch_mnist/mnist.py

+7
Original file line number | Diff line number | Diff line change
@@ -163,3 +163,10 @@ def test(model, test_loader, cuda):
163163
logger.debug('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
164164
test_loss, correct, len(test_loader.dataset),
165165
100. * correct / len(test_loader.dataset)))
166+
167+
168+
def model_fn(model_dir):
169+
model = torch.nn.DataParallel(Net())
170+
with open(os.path.join(model_dir, 'model'), 'rb') as f:
171+
model.load_state_dict(torch.load(f))
172+
return model

tests/integ/test_pytorch_train.py

+57 -27
Original file line number | Diff line number | Diff line change
@@ -10,55 +10,71 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
import numpy
1314
import os
15+
import sys
1416
import time
1517
import pytest
1618
from sagemaker.pytorch.estimator import PyTorch
19+
from sagemaker.pytorch.model import PyTorchModel
20+
from sagemaker.utils import sagemaker_timestamp
1721
from tests.integ import DATA_DIR
18-
from tests.integ.timeout import timeout
22+
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
1923

2024
MNIST_DIR = os.path.join(DATA_DIR, 'pytorch_mnist')
2125
MNIST_SCRIPT = os.path.join(MNIST_DIR, 'mnist.py')
26+
PYTHON_VERSION = 'py' + str(sys.version_info.major)
2227

2328

2429
@pytest.fixture(scope='module', name='pytorch_training_job')
25-
def fixture_training_job(sagemaker_session, pytorch_full_version, instance_type):
30+
def fixture_training_job(sagemaker_session, pytorch_full_version):
31+
instance_type = 'ml.c4.xlarge'
2632
with timeout(minutes=15):
27-
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
28-
train_instance_count=1, train_instance_type=instance_type,
29-
sagemaker_session=sagemaker_session)
33+
pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type)
3034

3135
pytorch.fit({'training': _upload_training_data(pytorch)})
3236
return pytorch.latest_training_job.name
3337

3438

35-
def test_sync_fit(sagemaker_session, pytorch_full_version):
36-
training_job_name = ""
39+
def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
3740
# TODO: add tests against local mode when it's ready to be used
38-
instance_type = 'ml.p2.xlarge'
41+
endpoint_name = 'test-pytorch-sync-fit-attach-deploy{}'.format(sagemaker_timestamp())
42+
with timeout(minutes=20):
43+
estimator = PyTorch.attach(pytorch_training_job, sagemaker_session=sagemaker_session)
44+
predictor = estimator.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
45+
data = numpy.zeros(shape=(1, 1, 28, 28))
46+
predictor.predict(data)
3947

40-
with timeout(minutes=15):
41-
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
42-
train_instance_count=1, train_instance_type=instance_type,
43-
sagemaker_session=sagemaker_session)
48+
batch_size = 100
49+
data = numpy.random.rand(batch_size, 1, 28, 28)
50+
output = predictor.predict(data)
4451

45-
pytorch.fit({'training': _upload_training_data(pytorch)})
46-
training_job_name = pytorch.latest_training_job.name
52+
assert numpy.asarray(output).shape == (batch_size, 10)
4753

48-
if not _is_local_mode(instance_type):
49-
with timeout(minutes=20):
50-
PyTorch.attach(training_job_name, sagemaker_session=sagemaker_session)
5154

55+
def test_deploy_model(pytorch_training_job, sagemaker_session):
56+
endpoint_name = 'test-pytorch-deploy-model-{}'.format(sagemaker_timestamp())
57+
58+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
59+
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=pytorch_training_job)
60+
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
61+
model = PyTorchModel(model_data, 'SageMakerRole', entry_point=MNIST_SCRIPT, sagemaker_session=sagemaker_session)
62+
predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
63+
64+
batch_size = 100
65+
data = numpy.random.rand(batch_size, 1, 28, 28)
66+
output = predictor.predict(data)
5267

53-
def test_async_fit(sagemaker_session, pytorch_full_version):
68+
assert numpy.asarray(output).shape == (batch_size, 10)
69+
70+
71+
def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
5472
training_job_name = ""
5573
# TODO: add tests against local mode when it's ready to be used
56-
instance_type = 'ml.c4.xlarge'
74+
instance_type = 'ml.p2.xlarge'
5775

5876
with timeout(minutes=10):
59-
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
60-
train_instance_count=1, train_instance_type=instance_type,
61-
sagemaker_session=sagemaker_session)
77+
pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type)
6278

6379
pytorch.fit({'training': _upload_training_data(pytorch)}, wait=False)
6480
training_job_name = pytorch.latest_training_job.name
@@ -67,19 +83,26 @@ def test_async_fit(sagemaker_session, pytorch_full_version):
6783
time.sleep(20)
6884

6985
if not _is_local_mode(instance_type):
70-
with timeout(minutes=35):
86+
endpoint_name = 'test-pytorch-async-fit-attach-deploy-{}'.format(sagemaker_timestamp())
87+
88+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
7189
print("Re-attaching now to: %s" % training_job_name)
72-
PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
90+
estimator = PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
91+
predictor = estimator.deploy(1, instance_type, endpoint_name=endpoint_name)
92+
93+
batch_size = 100
94+
data = numpy.random.rand(batch_size, 1, 28, 28)
95+
output = predictor.predict(data)
96+
97+
assert numpy.asarray(output).shape == (batch_size, 10)
7398

7499

75100
# TODO(nadiaya): Run against local mode when errors will be propagated
76101
def test_failed_training_job(sagemaker_session, pytorch_full_version):
77102
script_path = os.path.join(MNIST_DIR, 'failure_script.py')
78103

79104
with timeout(minutes=15):
80-
pytorch = PyTorch(entry_point=script_path, role='SageMakerRole', framework_version=pytorch_full_version,
81-
train_instance_count=1, train_instance_type='ml.c4.xlarge',
82-
sagemaker_session=sagemaker_session)
105+
pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, entry_point=script_path)
83106

84107
with pytest.raises(ValueError) as e:
85108
pytorch.fit(_upload_training_data(pytorch))
@@ -91,5 +114,12 @@ def _upload_training_data(pytorch):
91114
key_prefix='integ-test-data/pytorch_mnist/training')
92115

93116

117+
def _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type='ml.c4.xlarge',
118+
entry_point=MNIST_SCRIPT):
119+
return PyTorch(entry_point=entry_point, role='SageMakerRole', framework_version=pytorch_full_version,
120+
py_version=PYTHON_VERSION, train_instance_count=1, train_instance_type=instance_type,
121+
sagemaker_session=sagemaker_session)
122+
123+
94124
def _is_local_mode(instance_type):
95125
return instance_type == 'local'

0 commit comments

Comments
 (0)