10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
+ import numpy
13
14
import os
15
+ import sys
14
16
import time
15
17
import pytest
16
18
from sagemaker .pytorch .estimator import PyTorch
19
+ from sagemaker .pytorch .model import PyTorchModel
20
+ from sagemaker .utils import sagemaker_timestamp
17
21
from tests .integ import DATA_DIR
18
- from tests .integ .timeout import timeout
22
+ from tests .integ .timeout import timeout , timeout_and_delete_endpoint_by_name
19
23
20
24
MNIST_DIR = os .path .join (DATA_DIR , 'pytorch_mnist' )
21
25
MNIST_SCRIPT = os .path .join (MNIST_DIR , 'mnist.py' )
26
+ PYTHON_VERSION = 'py' + str (sys .version_info .major )
22
27
23
28
24
29
@pytest.fixture(scope='module', name='pytorch_training_job')
def fixture_training_job(sagemaker_session, pytorch_full_version):
    """Run one PyTorch MNIST training job for the whole module.

    Returns the SageMaker training-job name so dependent tests can
    ``PyTorch.attach`` to the finished job instead of re-training.
    """
    instance_type = 'ml.c4.xlarge'

    # Bound the whole train cycle; an integration training run should
    # finish comfortably inside 15 minutes.
    with timeout(minutes=15):
        pytorch = _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type)

        pytorch.fit({'training': _upload_training_data(pytorch)})
        return pytorch.latest_training_job.name
33
37
34
38
35
def test_sync_fit_deploy(pytorch_training_job, sagemaker_session):
    """Attach to the module-level training job, deploy it, and run one prediction."""
    # TODO: add tests against local mode when it's ready to be used
    endpoint_name = 'test-pytorch-sync-fit-attach-deploy{}'.format(sagemaker_timestamp())
    # Use timeout_and_delete_endpoint_by_name (as the other deploy tests do)
    # so the endpoint is torn down even when the test fails mid-way;
    # a plain timeout() would leak the endpoint.
    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        estimator = PyTorch.attach(pytorch_training_job, sagemaker_session=sagemaker_session)
        predictor = estimator.deploy(1, 'ml.c4.xlarge', endpoint_name=endpoint_name)
        # Single zeroed MNIST-shaped batch (N, C, H, W) just to prove the
        # endpoint answers; the prediction value itself is not checked.
        data = numpy.zeros(shape=(1, 1, 28, 28))
        predictor.predict(data)
49
def test_deploy_model(pytorch_training_job, sagemaker_session):
    """Build a PyTorchModel from the training job's artifacts and deploy it."""
    endpoint_name = 'test-pytorch-deploy-model-{}'.format(sagemaker_timestamp())

    # Ensures the endpoint is deleted even if deploy/predict raises.
    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        desc = sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=pytorch_training_job)
        model_data = desc['ModelArtifacts']['S3ModelArtifacts']
        model = PyTorchModel(model_data, 'SageMakerRole', entry_point=MNIST_SCRIPT, sagemaker_session=sagemaker_session)
        predictor = model.deploy(1, 'ml.m4.xlarge', endpoint_name=endpoint_name)

        # Smoke-test the endpoint with one zeroed MNIST-shaped input.
        data = numpy.zeros(shape=(1, 1, 28, 28))
        predictor.predict(data)
52
60
53
- def test_async_fit (sagemaker_session , pytorch_full_version ):
61
+
62
+ def test_async_fit_deploy (sagemaker_session , pytorch_full_version ):
54
63
training_job_name = ""
55
64
# TODO: add tests against local mode when it's ready to be used
56
- instance_type = 'ml.c4 .xlarge'
65
+ instance_type = 'ml.p2 .xlarge'
57
66
58
67
with timeout (minutes = 10 ):
59
- pytorch = PyTorch (entry_point = MNIST_SCRIPT , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
60
- train_instance_count = 1 , train_instance_type = instance_type ,
61
- sagemaker_session = sagemaker_session )
68
+ pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version , instance_type )
62
69
63
70
pytorch .fit ({'training' : _upload_training_data (pytorch )}, wait = False )
64
71
training_job_name = pytorch .latest_training_job .name
@@ -67,19 +74,22 @@ def test_async_fit(sagemaker_session, pytorch_full_version):
67
74
time .sleep (20 )
68
75
69
76
if not _is_local_mode (instance_type ):
70
- with timeout (minutes = 35 ):
77
+ endpoint_name = 'test-pytorch-async-fit-attach-deploy-{}' .format (sagemaker_timestamp ())
78
+
79
+ with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
71
80
print ("Re-attaching now to: %s" % training_job_name )
72
- PyTorch .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
81
+ estimator = PyTorch .attach (training_job_name = training_job_name , sagemaker_session = sagemaker_session )
82
+ predictor = estimator .deploy (1 , instance_type , endpoint_name = endpoint_name )
83
+ data = numpy .zeros (shape = (1 , 1 , 28 , 28 ))
84
+ predictor .predict (data )
73
85
74
86
75
87
# TODO(nadiaya): Run against local mode when errors will be propagated
76
88
def test_failed_training_job (sagemaker_session , pytorch_full_version ):
77
89
script_path = os .path .join (MNIST_DIR , 'failure_script.py' )
78
90
79
91
with timeout (minutes = 15 ):
80
- pytorch = PyTorch (entry_point = script_path , role = 'SageMakerRole' , framework_version = pytorch_full_version ,
81
- train_instance_count = 1 , train_instance_type = 'ml.c4.xlarge' ,
82
- sagemaker_session = sagemaker_session )
92
+ pytorch = _get_pytorch_estimator (sagemaker_session , pytorch_full_version , entry_point = script_path )
83
93
84
94
with pytest .raises (ValueError ) as e :
85
95
pytorch .fit (_upload_training_data (pytorch ))
@@ -91,5 +101,12 @@ def _upload_training_data(pytorch):
91
101
key_prefix = 'integ-test-data/pytorch_mnist/training' )
92
102
93
103
104
def _get_pytorch_estimator(sagemaker_session, pytorch_full_version, instance_type='ml.c4.xlarge',
                           entry_point=MNIST_SCRIPT):
    """Build a single-instance PyTorch estimator with the defaults shared by these tests.

    Centralizes the role / framework / py_version boilerplate so each test
    only varies ``instance_type`` and ``entry_point``.
    """
    return PyTorch(entry_point=entry_point, role='SageMakerRole', framework_version=pytorch_full_version,
                   py_version=PYTHON_VERSION, train_instance_count=1, train_instance_type=instance_type,
                   sagemaker_session=sagemaker_session)
109
+
110
+
94
111
def _is_local_mode (instance_type ):
95
112
return instance_type == 'local'
0 commit comments