Skip to content

Commit 7ed5cc0

Browse files
committed
Add integ tests for prediction. Change tests to pick image python version based on the current environment python version.
1 parent 58b9743 commit 7ed5cc0

File tree

2 files changed

+51
-27
lines changed

2 files changed

+51
-27
lines changed

tests/data/pytorch_mnist/mnist.py

+7
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,10 @@ def test(model, test_loader, cuda):
163163
logger.debug('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
164164
test_loss, correct, len(test_loader.dataset),
165165
100. * correct / len(test_loader.dataset)))
166+
167+
168+
def model_fn(model_dir):
169+
model = torch.nn.DataParallel(Net())
170+
with open(os.path.join(model_dir, 'model'), 'rb') as f:
171+
model.load_state_dict(torch.load(f))
172+
return model

tests/integ/test_pytorch_train.py

+44-27
Original file line numberDiff line numberDiff line change
@@ -10,55 +10,62 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
import numpy
1314
import os
15+
import sys
1416
import time
1517
import pytest
1618
from sagemaker.pytorch.estimator import PyTorch
19+
from sagemaker.pytorch.model import PyTorchModel
20+
from sagemaker.utils import sagemaker_timestamp
1721
from tests.integ import DATA_DIR
18-
from tests.integ.timeout import timeout
22+
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
1923

2024
MNIST_DIR = os.path.join(DATA_DIR, 'pytorch_mnist')
2125
MNIST_SCRIPT = os.path.join(MNIST_DIR, 'mnist.py')
26+
PYTHON_VERSION = 'py' + str(sys.version_info.major)
2227

2328

2429
@pytest.fixture(scope='module', name='pytorch_training_job')
25-
def fixture_training_job(sagemaker_session, pytorch_full_version, instance_type):
30+
def fixture_training_job(sagemaker_session, pytorch_full_version):
31+
instance_type = 'ml.c4.xlarge'
2632
with timeout(minutes=15):
27-
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
28-
train_instance_count=1, train_instance_type=instance_type,
29-
sagemaker_session=sagemaker_session)
33+
pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type)
3034

3135
pytorch.fit({'training': _upload_training_data(pytorch)})
3236
return pytorch.latest_training_job.name
3337

3438

35-
def test_sync_fit(sagemaker_session, pytorch_full_version):
36-
training_job_name = ""
39+
def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
3740
# TODO: add tests against local mode when it's ready to be used
38-
instance_type = 'ml.p2.xlarge'
41+
endpoint_name = 'test-pytorch-sync-fit-attach-deploy{}'.format(sagemaker_timestamp())
42+
with timeout(minutes=20):
43+
estimator = PyTorch.attach(pytorch_training_job, sagemaker_session=sagemaker_session)
44+
predictor = estimator.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
45+
data = numpy.zeros(shape=(1, 1, 28, 28))
46+
predictor.predict(data)
3947

40-
with timeout(minutes=15):
41-
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
42-
train_instance_count=1, train_instance_type=instance_type,
43-
sagemaker_session=sagemaker_session)
4448

45-
pytorch.fit({'training': _upload_training_data(pytorch)})
46-
training_job_name = pytorch.latest_training_job.name
49+
def test_deploy_model(pytorch_training_job, sagemaker_session):
50+
endpoint_name = 'test-pytorch-deploy-model-{}'.format(sagemaker_timestamp())
4751

48-
if not _is_local_mode(instance_type):
49-
with timeout(minutes=20):
50-
PyTorch.attach(training_job_name, sagemaker_session=sagemaker_session)
52+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
53+
desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=pytorch_training_job)
54+
model_data = desc['ModelArtifacts']['S3ModelArtifacts']
55+
model = PyTorchModel(model_data, 'SageMakerRole', entry_point=MNIST_SCRIPT, sagemaker_session=sagemaker_session)
56+
predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)
5157

58+
data = numpy.zeros(shape=(1, 1, 28, 28))
59+
predictor.predict(data)
5260

53-
def test_async_fit(sagemaker_session, pytorch_full_version):
61+
62+
def test_async_fit_deploy(sagemaker_session, pytorch_full_version):
5463
training_job_name = ""
5564
# TODO: add tests against local mode when it's ready to be used
56-
instance_type = 'ml.c4.xlarge'
65+
instance_type = 'ml.p2.xlarge'
5766

5867
with timeout(minutes=10):
59-
pytorch = PyTorch(entry_point=MNIST_SCRIPT, role='SageMakerRole', framework_version=pytorch_full_version,
60-
train_instance_count=1, train_instance_type=instance_type,
61-
sagemaker_session=sagemaker_session)
68+
pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type)
6269

6370
pytorch.fit({'training': _upload_training_data(pytorch)}, wait=False)
6471
training_job_name = pytorch.latest_training_job.name
@@ -67,19 +74,22 @@ def test_async_fit(sagemaker_session, pytorch_full_version):
6774
time.sleep(20)
6875

6976
if not _is_local_mode(instance_type):
70-
with timeout(minutes=35):
77+
endpoint_name = 'test-pytorch-async-fit-attach-deploy-{}'.format(sagemaker_timestamp())
78+
79+
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
7180
print("Re-attaching now to: %s" % training_job_name)
72-
PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
81+
estimator = PyTorch.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
82+
predictor = estimator.deploy(1, instance_type, endpoint_name=endpoint_name)
83+
data = numpy.zeros(shape=(1, 1, 28, 28))
84+
predictor.predict(data)
7385

7486

7587
# TODO(nadiaya): Run against local mode when errors will be propagated
7688
def test_failed_training_job(sagemaker_session, pytorch_full_version):
7789
script_path = os.path.join(MNIST_DIR, 'failure_script.py')
7890

7991
with timeout(minutes=15):
80-
pytorch = PyTorch(entry_point=script_path, role='SageMakerRole', framework_version=pytorch_full_version,
81-
train_instance_count=1, train_instance_type='ml.c4.xlarge',
82-
sagemaker_session=sagemaker_session)
92+
pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, entry_point=script_path)
8393

8494
with pytest.raises(ValueError) as e:
8595
pytorch.fit(_upload_training_data(pytorch))
@@ -91,5 +101,12 @@ def _upload_training_data(pytorch):
91101
key_prefix='integ-test-data/pytorch_mnist/training')
92102

93103

104+
def _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type='ml.c4.xlarge',
105+
entry_point=MNIST_SCRIPT):
106+
return PyTorch(entry_point=entry_point, role='SageMakerRole', framework_version=pytorch_full_version,
107+
py_version=PYTHON_VERSION, train_instance_count=1, train_instance_type=instance_type,
108+
sagemaker_session=sagemaker_session)
109+
110+
94111
def _is_local_mode(instance_type):
95112
return instance_type == 'local'

0 commit comments

Comments
 (0)