10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
+ import numpy
13
14
import os
15
+ import sys
14
16
import time
15
17
import pytest
16
18
from sagemaker .pytorch .estimator import PyTorch
19
+ from sagemaker .pytorch .model import PyTorchModel
20
+ from sagemaker .utils import sagemaker_timestamp
17
21
from tests .integ import DATA_DIR
18
- from tests .integ .timeout import timeout
22
+ from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
19
23
20
24
MNIST_DIR = os .path .join (DATA_DIR , 'pytorch_mnist' )
21
25
MNIST_SCRIPT = os .path .join (MNIST_DIR , 'mnist.py' )
26
+ PYTHON_VERSION = 'py' + str (sys .version_info .major )
22
27
23
28
24
29
@pytest .fixture (scope = 'module' , name = 'pytorch_training_job' )
25
- def fixture_training_job (sagemaker_session , pytorch_full_version , instance_type ):
30
+ def fixture_training_job (sagemaker_session , pytorch_full_version ):
31
+ instance_type = 'ml.c4.xlarge'
26
32
with timeout (minutes = 15 ):
27
- pytorch = PyTorch (entry_point = MNIST_SCRIPT , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
28
- train_instance_count = 1 , train_instance_type = instance_type ,
29
- sagemaker_session = sagemaker_session )
33
+ pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version , instance_type )
30
34
31
35
pytorch .fit ({'training' : _upload_training_data (pytorch )})
32
36
return pytorch .latest_training_job .name
33
37
34
38
35
- def test_sync_fit (sagemaker_session , pytorch_full_version ):
36
- training_job_name = ""
39
+ def test_sync_fit_deploy (pytorch_training_job , sagemaker_session ):
37
40
# TODO: add tests against local mode when it's ready to be used
38
- instance_type = 'ml.p2.xlarge'
41
+ endpoint_name = 'test-pytorch-sync-fit-attach-deploy{}' .format (sagemaker_timestamp ())
42
+ with timeout (minutes = 20 ):
43
+ estimator = PyTorch .attach (pytorch_training_job , sagemaker_session = sagemaker_session )
44
+ predictor = estimator .deploy (1 , 'ml.c4.xlarge' , endpoint_name = endpoint_name )
45
+ data = numpy .zeros (shape = (1 , 1 , 28 , 28 ))
46
+ predictor .predict (data )
39
47
40
- with timeout (minutes = 15 ):
41
- pytorch = PyTorch (entry_point = MNIST_SCRIPT , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
42
- train_instance_count = 1 , train_instance_type = instance_type ,
43
- sagemaker_session = sagemaker_session )
48
+ batch_size = 100
49
+ data = numpy .random .rand (batch_size , 1 , 28 , 28 )
50
+ output = predictor .predict (data )
44
51
45
- pytorch .fit ({'training' : _upload_training_data (pytorch )})
46
- training_job_name = pytorch .latest_training_job .name
52
+ assert numpy .asarray (output ).shape == (batch_size , 10 )
47
53
48
- if not _is_local_mode (instance_type ):
49
- with timeout (minutes = 20 ):
50
- PyTorch .attach (training_job_name , sagemaker_session = sagemaker_session )
51
54
55
+ def test_deploy_model (pytorch_training_job , sagemaker_session ):
56
+ endpoint_name = 'test-pytorch-deploy-model-{}' .format (sagemaker_timestamp ())
57
+
58
+ with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
59
+ desc = sagemaker_session .sagemaker_client .describe_training_job (TrainingJobName = pytorch_training_job )
60
+ model_data = desc ['ModelArtifacts' ]['S3ModelArtifacts' ]
61
+ model = PyTorchModel (model_data , 'SageMakerRole' , entry_point = MNIST_SCRIPT , sagemaker_session = sagemaker_session )
62
+ predictor = model .deploy (1 , 'ml.m4.xlarge' , endpoint_name = endpoint_name )
63
+
64
+ batch_size = 100
65
+ data = numpy .random .rand (batch_size , 1 , 28 , 28 )
66
+ output = predictor .predict (data )
52
67
53
- def test_async_fit (sagemaker_session , pytorch_full_version ):
68
+ assert numpy .asarray (output ).shape == (batch_size , 10 )
69
+
70
+
71
+ def test_async_fit_deploy (sagemaker_session , pytorch_full_version ):
54
72
training_job_name = ""
55
73
# TODO: add tests against local mode when it's ready to be used
56
- instance_type = 'ml.c4 .xlarge'
74
+ instance_type = 'ml.p2 .xlarge'
57
75
58
76
with timeout (minutes = 10 ):
59
- pytorch = PyTorch (entry_point = MNIST_SCRIPT , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
60
- train_instance_count = 1 , train_instance_type = instance_type ,
61
- sagemaker_session = sagemaker_session )
77
+ pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version , instance_type )
62
78
63
79
pytorch .fit ({'training' : _upload_training_data (pytorch )}, wait = False )
64
80
training_job_name = pytorch .latest_training_job .name
@@ -67,19 +83,26 @@ def test_async_fit(sagemaker_session, pytorch_full_version):
67
83
time .sleep (20 )
68
84
69
85
if not _is_local_mode (instance_type ):
70
- with timeout (minutes = 35 ):
86
+ endpoint_name = 'test-pytorch-async-fit-attach-deploy-{}' .format (sagemaker_timestamp ())
87
+
88
+ with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
71
89
print ("Re-attaching now to: %s" % training_job_name )
72
- PyTorch .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
90
+ estimator = PyTorch .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
91
+ predictor = estimator .deploy (1 , instance_type , endpoint_name = endpoint_name )
92
+
93
+ batch_size = 100
94
+ data = numpy .random .rand (batch_size , 1 , 28 , 28 )
95
+ output = predictor .predict (data )
96
+
97
+ assert numpy .asarray (output ).shape == (batch_size , 10 )
73
98
74
99
75
100
# TODO(nadiaya): Run against local mode when errors will be propagated
76
101
def test_failed_training_job (sagemaker_session , pytorch_full_version ):
77
102
script_path = os .path .join (MNIST_DIR , 'failure_script.py' )
78
103
79
104
with timeout (minutes = 15 ):
80
- pytorch = PyTorch (entry_point = script_path , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
81
- train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
82
- sagemaker_session = sagemaker_session )
105
+ pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version , entry_point = script_path )
83
106
84
107
with pytest .raises (ValueError ) as e :
85
108
pytorch .fit (_upload_training_data (pytorch ))
@@ -91,5 +114,12 @@ def _upload_training_data(pytorch):
91
114
key_prefix = 'integ-test-data/pytorch_mnist/training' )
92
115
93
116
117
+ def _get_pytorch_estimator (sagemaker_session , pytorch_full_version , instance_type = 'ml.c4.xlarge' ,
118
+ entry_point = MNIST_SCRIPT ):
119
+ return PyTorch (entry_point = entry_point , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
120
+ py_version = PYTHON_VERSION , train_instance_count = 1 , train_instance_type = instance_type ,
121
+ sagemaker_session = sagemaker_session )
122
+
123
+
94
124
def _is_local_mode (instance_type ):
95
125
return instance_type == 'local'
0 commit comments